From f2a6851ad3e6e8499c57838db598375c0d043422 Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Mon, 1 Jul 2019 12:50:39 -0700 Subject: [PATCH 01/26] [Relay][Pass] Only allow Module -> Module for opts managed by pass infra (#3430) * [Relay][Pass] Only allow Module -> Module for opts managed by pass infra * revert gradient pass --- include/tvm/relay/pass.h | 86 +----------- include/tvm/relay/transform.h | 27 ++++ python/tvm/relay/ir_pass.py | 96 -------------- python/tvm/relay/transform.py | 40 +++++- src/relay/pass/dead_code.cc | 3 - src/relay/pass/partial_eval.cc | 31 +++-- src/relay/pass/pass_manager.cc | 15 +++ src/relay/pass/to_a_normal_form.cc | 84 +++++------- src/relay/pass/to_graph_normal_form.cc | 5 +- .../relay/test_pass_dead_code_elimination.py | 60 ++++++--- tests/python/relay/test_pass_partial_eval.py | 125 +++++++++++------- .../relay/test_pass_to_a_normal_form.py | 40 ++++-- .../relay/test_pass_to_graph_normal_form.py | 15 +-- 13 files changed, 278 insertions(+), 349 deletions(-) diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 294d22b812a13..79172c3743167 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -140,23 +140,6 @@ TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2); */ TVM_DLL bool AlphaEqual(const Pattern& t1, const Pattern& t2); -/*! - * \brief Add abstraction over a function - * - * For example: `square` is transformed to - * `fun x -> square x`. - * - * See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion - * for more details. - * - * \param e The original function. - * \param mod The module used for referencing global functions, can be - * None. - * - * \return the new function with abstraction - */ -TVM_DLL Expr EtaExpand(const Expr& e, const Module& mod); - /*! * \brief Check that each Var is only bound once. * @@ -288,24 +271,6 @@ TVM_DLL tvm::Array AllTypeVars(const Expr& expr, const Module& mod); */ TVM_DLL tvm::Array AllTypeVars(const Type& t, const Module& mod); -/*! \brief Remove expressions which does not effect the program result. - * - * It will remove let bindings which are not referenced, - * and inline let bindings that are only used once. - * - * For example, this pass should turn `let a = 1 in 2` into `2`, - * as the value of the expression does not depend on a. - * - * As another example, `let a = 1 in a` will be optimized into 1, - * if the flag is turned on. - * - * \param e the expression to optimize. - * \param inline_once whether or not to inline binding used one. - * - * \return the optimized expression. - */ -TVM_DLL Expr DeadCodeElimination(const Expr& e, bool inline_once = false); - /*! * \brief Fold constant expressions. * @@ -387,38 +352,6 @@ TVM_DLL Map CollectDeviceInfo(const Expr& expr); */ TVM_DLL Map CollectDeviceAnnotationOps(const Expr& expr); -/*! - * \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF). - * - * It will turn an expression that is in a graph form (with sharing implicit), - * to an expression with explicit sharing (A-Normal Form). - * - * The scope of the root expression is the global scope. - * - * The scope of any non root expression is the least common ancestor of all it's scope. - * - * Values are ordered by post-DFS order in each scope. - * - * \param e the expression to observably share. - * \param mod The module used for referencing global functions, can be - * None. - * - * \return expression in A-Normal Form. 
- */ -TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod); - -/*! - * \brief Remove let binding and directly share via pointer instead. - * - * It will remove all let binding, - * and turn all of the variable bound by let into direct pointer reference. - * - * \param e the expression. - * - * \return the expression in graph normal form. - */ -TVM_DLL Expr ToGraphNormalForm(const Expr& e); - /*! * \brief Finds cases that the given match expression does not catch, if any. * @@ -432,28 +365,17 @@ TVM_DLL Expr ToGraphNormalForm(const Expr& e); TVM_DLL Array UnmatchedCases(const Match& match, const Module& mod); /*! - * \brief Aggressive constant propagation/constant folding/inlining. - * It will do as much computation in compile time as possible. - * It has two benefit: remove runtime overhead, and allow more optimization (typically fusion). - * As a side effect, code size will explode. - * - * \param e the expression - * \param mod the module - * - * \return the optimized expression. - */ -TVM_DLL Expr PartialEval(const Expr& e, const Module& mod); - -/* - * \brief Bind function parameters or free variables. + * \brief Bind the free variables to a Relay expression. * * Parameter binding can only happen if expr is a Function. * binds cannot change internal arguments of internal functions. * * \param expr The function to be binded. * \param binds The map of arguments to + * + * \return The expression with all free vars bound. */ -TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& bind_map); +TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds); /*! \brief A hashing structure in the style of std::hash. */ struct StructuralHash { diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 04b4e64dc9c3b..9ae71d824f94e 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -541,6 +541,33 @@ TVM_DLL Pass AlterOpLayout(); */ TVM_DLL Pass CanonicalizeCast(); +/*! + * \brief Add abstraction over a function + * + * For example: `square` is transformed to + * `fun x -> square x`. + * + * See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion + * for more details. + * + * \return The pass. + */ +TVM_DLL Pass EtaExpand(); + +/*! + * \brief This is a helper function that runs a some optimization passes on + * a certain expression and returns the optimized version. With the help of this + * function, users don't need to manually construct a module, then perform + * passes, and finally and extract the target function/expression from the + * returned module frequently. + * + * \param expr The expression to be optimized. + * \param passes The passses that will be applied on the given expression. + * + * \return The optimized expression. + */ +TVM_DLL Expr OptimizeOnExpr(const Expr& expr, const Array& passes); + } // namespace transform } // namespace relay } // namespace tvm diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 1748571cb3163..52dc34d7aac9d 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -84,23 +84,6 @@ def backward_fold_scale_axis(expr): """ return _ir_pass.backward_fold_scale_axis(expr) -def eta_expand(expr, mod): - """Add abstraction over a function. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression, we expect that expr's types - should be fully inferred by infer_type. - mod : tvm.relay.Module - The global module. - - Returns - ------- - expanded_expr : tvm.relay.Expr - The expression after eta expansion. 
- """ - return _ir_pass.eta_expand(expr, mod) def forward_fold_scale_axis(expr): """Fold the scaling of axis into weights of conv2d/dense. @@ -318,25 +301,6 @@ def canonicalize_ops(expr): return _ir_pass.canonicalize_ops(expr) -def dead_code_elimination(expr, inline_once=False): - """ Remove expressions which does not effect the program result (dead code). - - Parameters - ---------- - expr : tvm.relay.Expr - The input Expression - - inline_once : Optional[Bool] - Whether to inline binding that occur only once. - Returns - ------- - result : tvm.relay.Expr - An expression which is semantically equal to the input expression, - but with dead code removed. - """ - return _ir_pass.dead_code_elimination(expr, inline_once) - - def alpha_equal(lhs, rhs): """Compare two Relay expr for structural equivalence (alpha equivalence). @@ -534,46 +498,6 @@ def collect_device_annotation_ops(expr): return _ir_pass.CollectDeviceAnnotationOps(expr) -def to_a_normal_form(expr, mod=None): - """ - Turn Graph Normal Form expression into A Normal Form Expression. - - The scope of the root expression is the global scope. - - The scope of any non root expression is the least common ancestor of all it's scope. - - Values are ordered by post-DFS order in each scope. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - - mod : Optional[tvm.relay.Module] - The global module. - - Returns - ------- - result : tvm.relay.Expr - The output expression. - """ - return _ir_pass.to_a_normal_form(expr, mod) - - -def to_graph_normal_form(expr): - """Turn A Normal Form expression into Graph Normal Form expression - Parameters - ---------- - expr : tvm.relay.Expr - The input expression - Returns - ------- - result : tvm.relay.Expr - The output expression - """ - return _ir_pass.to_graph_normal_form(expr) - - def gradient(expr, mod=None, mode='higher_order'): """ Transform the input function, @@ -642,26 +566,6 @@ def eliminate_common_subexpr(expr, fskip=None): return _ir_pass.eliminate_common_subexpr(expr, fskip) -def partial_evaluate(expr, mod=None): - """ - Evaluate the static fragment of the code. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - - mod : Optional[tvm.relay.Module] - The global module - - Returns - ------- - result : tvm.relay.Expr - The output expression. - """ - return _ir_pass.partial_evaluate(expr, mod) - - def unmatched_cases(match, mod=None): """ Finds cases that the match expression does not catch, if any. diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 5f47e5b446aa7..ba4857dc4d36e 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -302,15 +302,20 @@ def CanonicalizeOps(): return _transform.CanonicalizeOps() -def DeadCodeElimination(): - """ Remove expressions which does not effect the program result (dead code). +def DeadCodeElimination(inline_once=False): + """Remove expressions which does not effect the program result (dead code). + + Parameters + ---------- + inline_once: Optional[Bool] + Whether to inline binding that occurs only once. Returns ------- ret: tvm.relay.Pass The registered pass that eliminates the dead code in a Relay program. 
""" - return _transform.DeadCodeElimination() + return _transform.DeadCodeElimination(inline_once) def FoldConstant(): @@ -406,6 +411,7 @@ def ToANormalForm(): """ return _transform.ToANormalForm() + def EtaExpand(): """Add abstraction over a function @@ -416,6 +422,7 @@ def EtaExpand(): """ return _transform.EtaExpand() + def ToGraphNormalForm(): """Turn A Normal Form expression into Graph Normal Form expression @@ -449,7 +456,7 @@ def PartialEvaluate(): Returns ------- - ret : tvm.relay.Pass + ret: tvm.relay.Pass The registered pass that performs partial evaluation on an expression. """ return _transform.PartialEvaluate() @@ -465,6 +472,31 @@ def CanonicalizeCast(): """ return _transform.CanonicalizeCast() + +def OptimizeOnExpr(expr, passes): + """Perform optimization passes on an expressioin. + + Parameters + ---------- + expr: tvm.relay.Expr + The expression for optimization. + + passes: Union[Pass, List[Pass]] + The list of optimizations to be applied. + + Returns + ------- + ret: tvm.relay.Expr + The optimized expression. + """ + if isinstance(passes, Pass): + passes = [passes] + if not isinstance(passes, (list, tuple)): + raise TypeError("passes must be a pass or a list of pass objects.") + + return _transform.OptimizeOnExpr(expr, passes) + + def _wrap_class_module_pass(pass_cls, pass_info): """Wrap a python class as function pass""" class PyModulePass(ModulePass): diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc index 7e186f80df929..8799bf403375e 100644 --- a/src/relay/pass/dead_code.cc +++ b/src/relay/pass/dead_code.cc @@ -156,9 +156,6 @@ Expr DeadCodeElimination(const Expr& e, bool inline_once) { return CalcDep::Eliminate(e, inline_once); } -TVM_REGISTER_API("relay._ir_pass.dead_code_elimination") -.set_body_typed(DeadCodeElimination); - namespace transform { Pass DeadCodeElimination(bool inline_once) { diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index b95c5844f8a40..e7edbb3153d85 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -1086,27 +1086,30 @@ Expr PostProcess(const Expr& e) { } // namespace partial_eval -Expr PartialEval(const Expr& e, const Module& m) { - return TransformF([&](const Expr& e) { +Module PartialEval(const Module& m) { + CHECK(m->entry_func.defined()); + auto func = m->Lookup(m->entry_func); + Expr ret = + TransformF([&](const Expr& e) { return LetList::With([&](LetList* ll) { - relay::partial_eval::PartialEvaluator pe(FreeVars(e), m); - pe.InitializeFuncId(e); - return relay::partial_eval::PostProcess(pe.VisitExpr(e, ll)->dynamic); - }); - }, e); + relay::partial_eval::PartialEvaluator pe(FreeVars(e), m); + pe.InitializeFuncId(e); + return relay::partial_eval::PostProcess(pe.VisitExpr(e, ll)->dynamic); + }); + }, func); + CHECK(ret->is_type()); + m->Update(m->entry_func, Downcast(ret)); + return m; } -TVM_REGISTER_API("relay._ir_pass.partial_evaluate") -.set_body_typed(PartialEval); - namespace transform { Pass PartialEval() { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { - return Downcast(PartialEval(f, m)); + runtime::TypedPackedFunc pass_func = + [=](Module m, PassContext pc) { + return PartialEval(m); }; - return CreateFunctionPass(pass_func, 1, "PartialEvaluate", {}); + return CreateModulePass(pass_func, 1, "PartialEvaluate", {}); } TVM_REGISTER_API("relay._transform.PartialEvaluate") diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index d63d9121fe27e..a620316035c7e 100644 --- 
a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -573,6 +573,18 @@ class PassContext::Internal { } }; +Expr OptimizeOnExpr(const Expr& expr, const Array& passes) { + auto mod = ModuleNode::FromExpr(expr); + Sequential seq(passes); + auto pass_ctx = PassContext::Create(); + pass_ctx->opt_level = 3; + tvm::With ctx_scope(pass_ctx); + mod = seq(mod); + CHECK(mod.defined()); + auto entry_func = mod->Lookup(mod->entry_func); + return expr.as() == nullptr ? entry_func->body : entry_func; +} + TVM_REGISTER_API("relay._transform.GetCurrentPassContext") .set_body_typed(PassContext::Current); @@ -582,6 +594,9 @@ TVM_REGISTER_API("relay._transform.EnterPassContext") TVM_REGISTER_API("relay._transform.ExitPassContext") .set_body_typed(PassContext::Internal::ExitScope); +TVM_REGISTER_API("relay._transform.OptimizeOnExpr") +.set_body_typed(OptimizeOnExpr); + } // namespace transform } // namespace relay } // namespace tvm diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc index 324eddd21c5ca..b5a3f8552d8da 100644 --- a/src/relay/pass/to_a_normal_form.cc +++ b/src/relay/pass/to_a_normal_form.cc @@ -26,6 +26,8 @@ */ #include #include +#include +#include #include #include "let_list.h" #include "../../common/arena.h" @@ -35,10 +37,6 @@ namespace tvm { namespace relay { -Expr ToANormalForm(const Expr& e, - const Module& m, - std::unordered_set* gv); - struct ScopeNode; using Scope = std::shared_ptr; @@ -104,29 +102,21 @@ bool IsPrimitiveFunction(const Expr& e) { class Fill : ExprFunctor { public: static Expr ToANormalForm(const Expr& e, - const Module& m, const DependencyGraph& dg, - std::unordered_map* node_scope, - std::unordered_set* gv) { - Fill fi(m, dg, node_scope, gv); + std::unordered_map* node_scope) { + Fill fi(dg, node_scope); return fi.GetScope(e)->ll->Get(fi.VisitExpr(e)); } private: - Module mod_; const DependencyGraph& dg_; std::unordered_map* node_scope_; - std::unordered_set* visited_; std::unordered_map memo; - Fill(Module mod, - const DependencyGraph& dg, - std::unordered_map* node_scope, - std::unordered_set* visited) : - mod_(mod), + Fill(const DependencyGraph& dg, + std::unordered_map* node_scope) : dg_(dg), - node_scope_(node_scope), - visited_(visited) { } + node_scope_(node_scope) { } Scope GetScope(const Expr& e) { return node_scope_->at(dg_.expr_node.at(e)); @@ -246,10 +236,6 @@ class Fill : ExprFunctor { Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final { GlobalVar gv = GetRef(gvn); - if (visited_->count(gv) == 0) { - visited_->insert(gv); - mod_->Update(gv, Downcast(relay::ToANormalForm(mod_->Lookup(gv), mod_, visited_))); - } return Atomic(gv, gv, v); } @@ -276,9 +262,7 @@ class Fill : ExprFunctor { } }; -Expr ToANormalFormAux(const Expr& e, - const Module& m, - std::unordered_set* gv) { +Expr ToANormalFormAux(const Expr& e) { /* When you lift a lambda, what is inside is also being lift. * * So we must determine the scope of the lambda before determining the scope of it's body. @@ -301,46 +285,40 @@ Expr ToANormalFormAux(const Expr& e, * We do an additional pass to fill all the LetList and we are done. 
*/ std::unordered_map node_scope = CalcScope(dg); - return Fill::ToANormalForm(e, m, dg, &node_scope, gv); + return Fill::ToANormalForm(e, dg, &node_scope); } -Expr ToANormalForm(const Expr& e, - const Module& m, - std::unordered_set* gv) { - DLOG(INFO) - << "ToANF:" << std::endl - << AsText(e, false); - - Expr ret = - TransformF([&](const Expr& e) { - return ToANormalFormAux(e, m, gv); - }, e); - - CHECK_EQ(FreeVars(ret).size(), 0); +Module ToANormalForm(const Module& m) { + DLOG(INFO) << "ToANF:" << std::endl << m; + + tvm::Map updates; + auto funcs = m->functions; + for (const auto& it : funcs) { + Expr ret = + TransformF([&](const Expr& e) { + return ToANormalFormAux(e); + }, it.second); + CHECK_EQ(FreeVars(ret).size(), 0); + updates.Set(it.first, Downcast(ret)); + } - DLOG(INFO) - << "ToANF: transformed" << std::endl - << AsText(ret, false); + for (auto pair : updates) { + m->Add(pair.first, pair.second, true); + } - return ret; -} + DLOG(INFO) << "ToANF: transformed" << std::endl << m; -Expr ToANormalForm(const Expr& e, const Module& m) { - std::unordered_set gv; - return ToANormalForm(e, m, &gv); + return m; } -TVM_REGISTER_API("relay._ir_pass.to_a_normal_form") -.set_body_typed(static_cast(ToANormalForm)); - namespace transform { Pass ToANormalForm() { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { - return Downcast(ToANormalForm(f, m)); + runtime::TypedPackedFunc pass_func = + [=](Module m, PassContext pc) { + return ToANormalForm(m); }; - return CreateFunctionPass(pass_func, 1, "ToANormalForm", {}); + return CreateModulePass(pass_func, 1, "ToANormalForm", {}); } TVM_REGISTER_API("relay._transform.ToANormalForm") diff --git a/src/relay/pass/to_graph_normal_form.cc b/src/relay/pass/to_graph_normal_form.cc index 9c166f98c1a5c..c1ae19e92748e 100644 --- a/src/relay/pass/to_graph_normal_form.cc +++ b/src/relay/pass/to_graph_normal_form.cc @@ -24,8 +24,8 @@ * * \brief Turn A normal form into graph normal form. 
*/ -#include #include +#include #include "let_list.h" namespace tvm { @@ -76,9 +76,6 @@ Expr ToGraphNormalForm(const Expr& e) { return GNF()(e); } -TVM_REGISTER_API("relay._ir_pass.to_graph_normal_form") -.set_body_typed(ToGraphNormalForm); - namespace transform { Pass ToGraphNormalForm() { diff --git a/tests/python/relay/test_pass_dead_code_elimination.py b/tests/python/relay/test_pass_dead_code_elimination.py index 9158f0729d614..c3b12fea44867 100644 --- a/tests/python/relay/test_pass_dead_code_elimination.py +++ b/tests/python/relay/test_pass_dead_code_elimination.py @@ -18,20 +18,13 @@ import tvm from tvm import relay -from tvm.relay.ir_pass import dead_code_elimination, alpha_equal +from tvm.relay import Function, transform +from tvm.relay.ir_pass import alpha_equal, graph_equal, free_vars from tvm.relay.op import log, add, equal, subtract class env: def __init__(self): - self.a = relay.Var("a") - self.b = relay.Var("b") - self.c = relay.Var("c") - self.d = relay.Var("d") - self.e = relay.Var("e") - self.x = relay.Var("x") - self.y = relay.Var("y") - self.z = relay.Var("z") self.shape = tvm.convert([1, 2, 3]) self.tt = relay.TensorType(self.shape, "float32") self.int32 = relay.TensorType([], "int32") @@ -39,6 +32,14 @@ def __init__(self): self.one = relay.const(1.0) self.two = relay.const(2.0) self.three = relay.const(3.0) + self.a = relay.Var("a", self.float32) + self.b = relay.Var("b", self.float32) + self.c = relay.Var("c", self.float32) + self.d = relay.Var("d", self.float32) + self.e = relay.Var("e", self.float32) + self.x = relay.Var("x", self.int32) + self.y = relay.Var("y", self.int32) + self.z = relay.Var("z", self.int32) e = env() @@ -46,22 +47,27 @@ def __init__(self): def test_let(): orig = relay.Let(e.x, e.y, e.z) - assert alpha_equal(dead_code_elimination(orig), e.z) + orig = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination()) + assert alpha_equal(Function(free_vars(orig), orig), Function([e.z], e.z)) def test_used_let(): orig = relay.Let(e.c, e.one, e.c + e.c) - assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.one, e.c + e.c)) + orig = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination()) + expected = relay.Let(e.c, e.one, e.c + e.c) + assert alpha_equal(Function([e.c], orig), Function([e.c], expected)) @nottest def test_inline(): orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c)) - assert alpha_equal(dead_code_elimination(orig), e.d) + orig = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination()) + assert alpha_equal(Function(free_vars(orig), orig), Function([e.d], e.d)) def test_chain_unused_let(): orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e)) - assert alpha_equal(dead_code_elimination(orig), e.e) + orig = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination()) + assert alpha_equal(Function(free_vars(orig), orig), Function([e.e], e.e)) # make sure we dont infinite loop @@ -78,27 +84,39 @@ def test_recursion(): f(2, 10000); """ f = relay.Var("f") + f1 = relay.Var("f1") n = relay.Var("n", e.int32) data = relay.Var("data", e.float32) funcbody = relay.If(equal(n, relay.const(0)), data, - relay.Call(f, [subtract(n, relay.const(1.0)), + relay.Call(f1, [subtract(n, relay.const(1)), log(data)])) value = relay.Function([n, data], funcbody, e.float32, []) - orig = relay.Let(f, value, relay.Call(f, [relay.const(2.0), relay.const(10000.0)])) - assert alpha_equal(dead_code_elimination(orig), orig) - assert alpha_equal(dead_code_elimination(relay.Let(f, value, e.three)), e.three) + orig = relay.Let(f, value, 
relay.Call(f, [relay.const(2), relay.const(10000.0)])) + dced = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination()) + orig = transform.OptimizeOnExpr(orig, transform.InferType()) + assert graph_equal(dced, orig) + dced = transform.OptimizeOnExpr(relay.Let(f, value, e.three), + transform.DeadCodeElimination()) + assert alpha_equal(dced, e.three) def test_op_let(): - assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three), e.two)), add(e.three, e.two)) + dced = transform.OptimizeOnExpr(add(relay.Let(e.a, e.one, e.three), e.two), + transform.DeadCodeElimination()) + assert alpha_equal(dced, add(e.three, e.two)) def test_tuple_get_item(): - t = relay.Var('t') + tt = relay.TupleType([e.float32, e.float32]) + t = relay.Var('t', tt) + a = relay.Var('a') g = relay.TupleGetItem(t, 0) - assert alpha_equal(dead_code_elimination(g), g) - assert alpha_equal(dead_code_elimination(relay.TupleGetItem(relay.Let(e.a, e.one, t), 0)), g) + dced = transform.OptimizeOnExpr(g, transform.DeadCodeElimination()) + assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g), g)) + orig = relay.TupleGetItem(relay.Let(a, e.one, t), 0) + dced = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination()) + assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g), g)) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py index b3c0c28d26cb8..f2aedd1905d4e 100644 --- a/tests/python/relay/test_pass_partial_eval.py +++ b/tests/python/relay/test_pass_partial_eval.py @@ -18,17 +18,13 @@ import numpy as np import tvm from tvm import relay -from tvm.relay.ir_pass import partial_evaluate, alpha_equal, infer_type, dead_code_elimination -from tvm.relay.ir_pass import gradient -from tvm.relay import op, create_executor -from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue +from tvm.relay.ir_pass import alpha_equal, gradient from tvm.relay.prelude import Prelude -from tvm.relay import create_executor -from nose.tools import nottest +from tvm.relay import op, create_executor, transform from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate from tvm.relay import TensorType, Tuple, If, Module, Clause, PatternConstructor, PatternVar, Match -from tvm.relay import GlobalVar, Call, Type -from tvm.relay.testing import add_nat_definitions, count, make_nat_value, make_nat_expr +from tvm.relay import GlobalVar, Call +from tvm.relay.testing import add_nat_definitions, make_nat_expr def check_eval(expr, expected_result, mod=None, rtol=1e-07): ctx = tvm.context("llvm", 0) @@ -38,8 +34,25 @@ def check_eval(expr, expected_result, mod=None, rtol=1e-07): np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol) -def dcpe(expr, mod=None): - return dead_code_elimination(partial_evaluate(expr, mod=mod), inline_once=True) +def tipe(expr): + return transform.OptimizeOnExpr(expr, + [transform.InferType(), + transform.PartialEvaluate(), + transform.InferType()]) + + +def dcpe(expr, mod=None, grad=False): + passes = [transform.PartialEvaluate(), + transform.DeadCodeElimination(inline_once=True)] + if grad: + expr = gradient(expr) + if mod: + assert isinstance(expr, Function) + mod[mod.entry_func] = expr + seq = transform.Sequential(passes) + mod = seq(mod) + return mod[mod.entry_func] + return transform.OptimizeOnExpr(expr, passes) def test_tuple(): @@ -47,24 +60,31 @@ def test_tuple(): x = Var("x", t) body = 
TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1) f = Function([x], body, None, [t]) - assert alpha_equal(dcpe(f), relay.Function([x], x, None, [t])) + expected = relay.Function([x], x, None, [t]) + expected = transform.OptimizeOnExpr(expected, transform.InferType()) + assert alpha_equal(dcpe(f), expected) + def test_const_inline(): - d = Var("d") + t = relay.TensorType([], "float32") + d = Var("d", t) double = Function([d], d + d) orig = double(const(4.0)) assert alpha_equal(dcpe(orig), const(8.0)) def test_ref(): - d = relay.Var("d") - r = relay.Var("r") + t = relay.TensorType([], "float32") + d = relay.Var("d", t) + r = relay.Var("r", relay.RefType(t)) x = relay.Var("x") body = relay.RefRead(r) body = Let(x, RefWrite(r, RefRead(r) * RefRead(r)), body) body = Let(r, RefCreate(d), body) square = Function([d], body) - assert alpha_equal(dcpe(square), Function([d], d * d)) + expected = transform.OptimizeOnExpr(Function([d], d * d), + transform.InferType()) + assert alpha_equal(dcpe(square), expected) def test_empty_ad(): @@ -73,17 +93,19 @@ def test_empty_ad(): t = TensorType(shape, dtype) d = Var("d", t) f = Function([d], d) - g = dcpe(gradient(f)) + g = dcpe(f, grad=True) expected = Function([d], Tuple([d, Tuple([op.ones_like(d)])])) + expected = transform.OptimizeOnExpr(expected, transform.InferType()) assert alpha_equal(g, expected) + def test_ad(): shape = (10, 10) dtype = "float32" t = TensorType(shape, dtype) d = Var("d", t) f = Function([d], d * d) - g = dcpe(gradient(f)) + g = dcpe(f, grad=True) m = d * d x = relay.Var("x") o = op.ones_like(x) @@ -92,6 +114,7 @@ def test_ad(): body = Tuple([x, Tuple([grad])]) body = relay.Let(x1, o, body) expected = Function([d], relay.Let(x, m, body)) + expected = transform.OptimizeOnExpr(expected, transform.InferType()) assert alpha_equal(g, expected) @@ -107,8 +130,7 @@ def test_if_ref(): eff = Var("eff") body = Let(eff, body, RefRead(r)) f = Function([d], Let(r, RefCreate(const(1)), Let(u, update, body))) - f = infer_type(f) - pe_f = infer_type(partial_evaluate(f)) + pe_f = tipe(f) ex = create_executor() f_res = ex.evaluate(f)(const(True)) pe_f_res = ex.evaluate(pe_f)(const(True)) @@ -132,8 +154,7 @@ def test_function_invalidate(): body = Let(fet, fetch, body) body = Let(r, RefCreate(const(0)), body) f = Function([d], body) - f = infer_type(f) - pe_f = infer_type(partial_evaluate(f)) + pe_f = tipe(f) ex = create_executor() f_res = ex.evaluate(f)(const(True)) pe_f_res = ex.evaluate(pe_f)(const(True)) @@ -144,35 +165,30 @@ def test_function_invalidate(): def test_head_cons(): mod = Module() p = Prelude(mod) - def hd_impl(): - a = TypeVar("a") - x = Var("x", p.l(a)) - y = Var("y") - z = Var("z") - cons_case = Clause(PatternConstructor(p.cons, - [PatternVar(y), - PatternVar(z)]), - y) - y = Var("y") - z = Var("z") - return Function([x], Match(x, [cons_case]), a, [a]) + hd = p.hd t = TypeVar("t") x = Var("x", t) - hd = Var("hd") - body = Let(hd, hd_impl(), hd(p.cons(x, p.nil()))) + body = hd(p.cons(x, p.nil())) f = Function([x], body, None, [t]) - f = infer_type(f, mod=mod) - res = dcpe(f) + res = dcpe(f, mod) assert alpha_equal(res, Function([x], x, t, [t])) def test_map(): mod = Module() p = Prelude(mod) - f = Var("f") + f = GlobalVar("f") + t = TypeVar("t") + a = Var("a", t) + mod[f] = Function([a], a, t, [t]) orig = p.map(f, p.cons(const(1), p.cons(const(2), p.cons(const(3), p.nil())))) - expected = p.cons(f(const(1)), p.cons(f(const(2)), p.cons(f(const(3)), p.nil()))) - assert alpha_equal(dcpe(orig, mod=mod), expected) + expected = 
p.cons((const(1)), p.cons((const(2)), p.cons((const(3)), p.nil()))) + expected = Function([], expected) + mod[mod.entry_func] = expected + expected = mod[mod.entry_func] + orig = Function([], orig) + res = dcpe(orig, mod=mod) + assert alpha_equal(res.body, expected.body) def test_loop(): @@ -181,9 +197,12 @@ def test_loop(): x = Var("x", t) loop = GlobalVar("loop") mod[loop] = Function([x], loop(x), t, [t]) - res = dcpe(loop(const(1)), mod=mod) - expected = Call(loop, [const(1)], None, [None]) - assert alpha_equal(res, expected) + expected = Call(loop, [const(1)]) + mod[mod.entry_func] = Function([], expected) + expected = mod[mod.entry_func].body + call = Function([], loop(const(1))) + res = dcpe(call, mod=mod) + assert alpha_equal(res.body, expected) def test_swap_loop(): @@ -196,8 +215,9 @@ def test_swap_loop(): loop = GlobalVar("loop") mod[loop] = Function([x, y], loop(y, x), nat) prog = loop(make_nat_expr(p, 1), make_nat_expr(p, 2)) - res = dcpe(prog, mod=mod) - assert alpha_equal(prog, res) + res = Function([], prog) + res = dcpe(res, mod=mod) + assert alpha_equal(prog, res.body) def test_abs_diff(): @@ -217,8 +237,9 @@ def test_abs_diff(): x_s_case = Clause(PatternConstructor(p.s, [PatternVar(xp)]), Match(y, [y_z_case, y_s_case])) mod[diff] = Function([x, y], Match(x, [x_z_case, x_s_case])) orig = diff(make_nat_expr(p, 7), make_nat_expr(p, 3)) + orig = Function([], orig) res = dcpe(orig, mod=mod) - assert alpha_equal(res, make_nat_expr(p, 4)) + assert alpha_equal(res.body, make_nat_expr(p, 4)) def test_match_nat_id(): @@ -233,8 +254,9 @@ def test_match_nat_id(): s_case = Clause(PatternConstructor(p.s, [PatternVar(y)]), p.s(y)) mod[nat_id] = Function([x], Match(x, [z_case, s_case])) orig = nat_id(make_nat_expr(p, 3)) + orig = Function([], orig) res = dcpe(orig, mod=mod) - assert alpha_equal(res, make_nat_expr(p, 3)) + assert alpha_equal(res.body, make_nat_expr(p, 3)) def test_nat_id(): @@ -247,8 +269,9 @@ def test_nat_id(): nat_id = GlobalVar("nat_id") mod[nat_id] = Function([x], x) orig = nat_id(make_nat_expr(p, 3)) + orig = Function([], orig) res = dcpe(orig, mod=mod) - assert alpha_equal(res, make_nat_expr(p, 3)) + assert alpha_equal(res.body, make_nat_expr(p, 3)) def test_global_match_nat_id(): @@ -260,8 +283,9 @@ def test_global_match_nat_id(): z_case = Clause(PatternConstructor(p.z, []), p.z()) s_case = Clause(PatternConstructor(p.s, [PatternVar(x)]), p.s(x)) orig = Match(make_nat_expr(p, 3), [z_case, s_case]) + orig = Function([], orig) res = dcpe(orig, mod=mod) - assert alpha_equal(res, make_nat_expr(p, 3)) + assert alpha_equal(res.body, make_nat_expr(p, 3)) def test_double(): @@ -269,8 +293,9 @@ def test_double(): p = Prelude(mod) add_nat_definitions(p) orig = p.double(make_nat_expr(p, 3)) + orig = Function([], orig) res = dcpe(orig, mod=mod) - assert alpha_equal(res, make_nat_expr(p, 6)) + assert alpha_equal(res.body, make_nat_expr(p, 6)) if __name__ == '__main__': diff --git a/tests/python/relay/test_pass_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py index 9a2570eabb11b..e74168141e63c 100644 --- a/tests/python/relay/test_pass_to_a_normal_form.py +++ b/tests/python/relay/test_pass_to_a_normal_form.py @@ -17,9 +17,8 @@ import numpy as np import tvm from tvm import relay -from tvm.relay.ir_pass import to_a_normal_form, alpha_equal, infer_type, detect_feature -from tvm.relay import op, create_executor -from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue +from tvm.relay.ir_pass import alpha_equal, detect_feature +from tvm.relay 
import op, create_executor, transform from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions, count from tvm.relay.feature import Feature @@ -39,7 +38,7 @@ def test_explicit_bound(): z = op.add(y, y) f = relay.Function([], op.add(z, z)) assert not Feature.fLet in detect_feature(f) - anf = to_a_normal_form(f) + anf = transform.OptimizeOnExpr(f, transform.ToANormalForm()) assert Feature.fLet in detect_feature(anf) check_eval(f(), 8.0) check_eval(anf(), 8.0) @@ -53,7 +52,8 @@ def test_order(): x = relay.const(1) val = x + y * z check_eval(val, 7.0) - anf = infer_type(to_a_normal_form(val)) + anf = transform.OptimizeOnExpr(val, [transform.ToANormalForm(), + transform.InferType()]) a = relay.Var('a', relay.IncompleteType()) b = relay.Var('b', relay.IncompleteType()) c = relay.Var('c', relay.IncompleteType()) @@ -65,14 +65,16 @@ def test_order(): expected_output = relay.Let(c, z, expected_output) expected_output = relay.Let(b, y, expected_output) expected_output = relay.Let(a, x, expected_output) - expected_output = infer_type(expected_output) + expected_output = transform.OptimizeOnExpr(expected_output, + transform.InferType()) assert alpha_equal(anf, expected_output) def test_if(): cond = relay.const(True) x = relay.If(cond, relay.const(2), relay.const(3)) - anf = infer_type(to_a_normal_form(x)) + anf = transform.OptimizeOnExpr(x, [transform.ToANormalForm(), + transform.InferType()]) a = relay.Var('a', relay.IncompleteType()) b = relay.Var('b', relay.IncompleteType()) c = relay.Var('c', relay.IncompleteType()) @@ -82,7 +84,8 @@ def test_if(): expected_output = relay.If(c, true_branch, false_branch) expected_output = relay.Let(d, expected_output, d) expected_output = relay.Let(c, cond, expected_output) - expected_output = infer_type(expected_output) + expected_output = transform.OptimizeOnExpr(expected_output, + transform.InferType()) assert alpha_equal(anf, expected_output) @@ -114,7 +117,8 @@ def test_recursion(): mod[f] = value check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod) old_f = mod[f] - f = to_a_normal_form(f, mod=mod) + mod = transform.ToANormalForm()(mod) + f = mod[f] check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod) @@ -129,7 +133,8 @@ def test_ref(): body = relay.Let(iv, relay.RefRead(i), body) body = relay.Let(i, relay.RefCreate(relay.const(1)), body) check_eval(body, 3) - check_eval(to_a_normal_form(body), 3) + opt_body = transform.OptimizeOnExpr(body, transform.ToANormalForm()) + check_eval(opt_body, 3) def test_nat_add(): @@ -144,7 +149,12 @@ def test_nat_add(): intrp = create_executor(mod=mod, ctx=ctx, target="llvm") assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat()) assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2 - assert count(p, intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2 + expr = add(s(z()), s(z())) + f = relay.GlobalVar("f") + mod[f] = relay.Function([], expr) + mod = transform.ToANormalForm()(mod) + expr = mod["f"] + assert count(p, intrp.evaluate(expr.body)) == 2 assert Feature.fLet in detect_feature(mod[add]) @@ -155,14 +165,16 @@ def test_let(): body = relay.Let(y, x, x + y) body = relay.Let(x, d, body) check_eval(body, 8) - check_eval(to_a_normal_form(body), 8) + opt_body = transform.OptimizeOnExpr(body, transform.ToANormalForm()) + check_eval(opt_body, 8) def test_function(): - x = relay.Var("x") + t = relay.TensorType((), 'float32') + x = relay.Var("x", t) f = relay.Function([x], x + x) d = relay.const(4.0, 'float32') - anf_f = to_a_normal_form(f) + anf_f = 
transform.OptimizeOnExpr(f, transform.ToANormalForm()) assert isinstance(anf_f, relay.Function) check_eval(f(d), 8) check_eval(anf_f(d), 8) diff --git a/tests/python/relay/test_pass_to_graph_normal_form.py b/tests/python/relay/test_pass_to_graph_normal_form.py index 6d9bd6ac254ec..09db48f633d91 100644 --- a/tests/python/relay/test_pass_to_graph_normal_form.py +++ b/tests/python/relay/test_pass_to_graph_normal_form.py @@ -17,10 +17,9 @@ import numpy as np import tvm from tvm import relay -from tvm.relay.ir_pass import to_graph_normal_form, to_a_normal_form, alpha_equal, detect_feature -from tvm.relay import op, create_executor +from tvm.relay import op, create_executor, transform +from tvm.relay.ir_pass import detect_feature from tvm.relay.feature import Feature -from tvm.relay.backend.interpreter import Value, TupleValue def check_eval(expr, args, expected_result, mod=None, rtol=1e-07): @@ -41,9 +40,9 @@ def test_implicit_share(): body = relay.Let(z, op.add(y, y), op.add(z, z)) body = relay.Let(y, op.add(x, x), body) f = relay.Function([], relay.Let(x, relay.const(1), body)) - g = to_graph_normal_form(f) - assert "let" in f.astext() - assert not "let" in g.astext() + g = transform.OptimizeOnExpr(f, transform.ToGraphNormalForm()) + assert Feature.fLet in detect_feature(f) + assert not Feature.fLet in detect_feature(g) check_eval(f, [], 8.0) check_eval(g, [], 8.0) @@ -55,8 +54,8 @@ def test_round_trip(): body = relay.Let(z, op.add(y, y), op.add(z, z)) body = relay.Let(y, op.add(x, x), body) f = relay.Function([], relay.Let(x, relay.const(1), body)) - g = to_graph_normal_form(f) - h = to_a_normal_form(g) + g = transform.OptimizeOnExpr(f, transform.ToGraphNormalForm()) + h = transform.OptimizeOnExpr(g, transform.ToANormalForm()) assert Feature.fLet in detect_feature(f) assert not Feature.fLet in detect_feature(g) check_eval(f, [], 8.0) From 4273e461f2a3bb4ea6f82e90ac025d2bd04712de Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Mon, 1 Jul 2019 14:07:45 -0700 Subject: [PATCH 02/26] Migrate simplifier to new infra. 
(#3368) --- CMakeLists.txt | 6 +- include/tvm/arithmetic.h | 9 +- include/tvm/ir_pass.h | 1 - src/arithmetic/analyzer.cc | 1 + src/arithmetic/bound_deducer.cc | 145 ++++++++++-------- src/arithmetic/const_fold.h | 7 +- src/arithmetic/rewrite_simplify.cc | 1 - src/arithmetic/stmt_simplify.cc | 41 +---- src/lang/buffer.cc | 5 +- src/op/scan_op.cc | 6 +- src/pass/loop_partition.cc | 21 ++- src/pass/narrow_channel_access.cc | 7 +- src/pass/storage_rewrite.cc | 8 +- src/pass/vectorize_loop.cc | 9 +- src/schedule/message_passing.cc | 16 +- src/schedule/schedule_dataflow_rewrite.cc | 2 +- tests/cpp/ir_simplify_test.cc | 10 +- .../unittest/test_arith_deduce_bound.py | 31 ++-- tests/python/unittest/test_pass_basic.py | 3 - 19 files changed, 175 insertions(+), 154 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6500ba013e28f..c23d403bcb6a1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -154,7 +154,11 @@ file(GLOB_RECURSE NNVM_COMPILER_SRCS file(GLOB TOPI_SRCS topi/src/*.cc ) -file(GLOB_RECURSE HALIDEIR_SRCS 3rdparty/HalideIR/src/*.cpp) +file(GLOB_RECURSE HALIDEIR_SRCS + 3rdparty/HalideIR/src/base/*.cpp + 3rdparty/HalideIR/src/ir/*.cpp + 3rdparty/HalideIR/src/tvm/*.cpp +) list(APPEND COMPILER_SRCS ${HALIDEIR_SRCS}) file(GLOB RUNTIME_SRCS src/runtime/*.cc diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index 92f7399a89a57..446c4c0c19a91 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -623,12 +623,15 @@ IntSet Intersect(const Array& sets); * give the domain of each variables. Return undefined IntSet to * represent failure. * + * \note The returned set may be smaller than set that + * contains all possible values of v that satisfies the bound. + * * \param v The target variable to be deduced. * \param cond The conditional expression. * \param hint_map The domain of variable, used to help deduce. * \param relax_map The domain of each variable, used to relax the domain, - * The deduce bound mush implies e for all value in relax_map - * \return An integer set that can cover all the possible values. + * The deduce bound must implies e for all value in relax_map + * \return An integer set that always satisfies the condition. */ IntSet DeduceBound(Expr v, Expr cond, const Map& hint_map, @@ -641,7 +644,7 @@ IntSet DeduceBound(Expr v, Expr cond, * \param hint_map The domain of variable, used to help deduce. * \param relax_map The domain of each variable, used to relax the domain, * The deduce bound mush implies e for all value in relax_map - * \return An integer set that can cover all the possible values. + * \return An integer set that always satisfies the condition. 
*/ IntSet DeduceBound(Expr v, Expr cond, const std::unordered_map& hint_map, diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index e1c92e50e6ad1..98dbf6bb62906 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -27,7 +27,6 @@ #ifndef TVM_IR_PASS_H_ #define TVM_IR_PASS_H_ -#include #include #include #include diff --git a/src/arithmetic/analyzer.cc b/src/arithmetic/analyzer.cc index 2198aee934787..626fc18c57df9 100644 --- a/src/arithmetic/analyzer.cc +++ b/src/arithmetic/analyzer.cc @@ -106,6 +106,7 @@ bool Analyzer::CanProve(const Expr& expr) { Expr Analyzer::Simplify(const Expr& expr) { if (is_const(expr)) return expr; auto res = this->rewrite_simplify(expr); + if (is_const(res)) return res; res = this->canonical_simplify(res); return res; } diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 395a371f43af7..003ba8def7612 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -84,11 +84,11 @@ class BoundDeducer: public IRVisitor { void Deduce(); void Visit(const NodeRef& e) final { - if (!success) return; + if (!success_) return; if (e.get() == path_[iter_++]) { IRVisitor::Visit(e); } else { - success = false; + success_ = false; return; } } @@ -111,18 +111,18 @@ class BoundDeducer: public IRVisitor { void Visit_(const Add* op) final { bool left = op->a.get() == path_[iter_]; - result -= left ? op->b : op->a; + result_ -= left ? op->b : op->a; Visit(left ? op->a : op->b); } void Visit_(const Sub* op) final { bool left = op->a.get() == path_[iter_]; if (left) { - result += op->b; + result_ += op->b; } else { - result -= op->a; - result = - result; - is_greater = !is_greater; + result_ -= op->a; + result_ = - result_; + is_greater_ = !is_greater_; } Visit(left ? op->a : op->b); } @@ -130,43 +130,65 @@ class BoundDeducer: public IRVisitor { void Visit_(const Mul* op) final { bool left = op->a.get() == path_[iter_]; Expr operand = left ? op->b : op->a; + Expr target_var = left ? op->a : op->b; - SignType sign; + SignType sign_operand; if (operand.type().is_uint()) { - sign = kPositive; + sign_operand = kPositive; } else { - sign = expr_map_[operand].sign_type(); + sign_operand = expr_map_[operand].sign_type(); } - if (sign == SignType::kNegative) { - is_greater = !is_greater; - } else if (sign == SignType::kUnknown) { + if (sign_operand == SignType::kNegative) { + is_greater_ = !is_greater_; + } else if (sign_operand == SignType::kUnknown) { // unable to get the sign of operand - success = false; + success_ = false; return; } - // always use relax bound - bool divided = can_prove(result % operand == 0); - result = result / operand; - // since system will round down when not divided - // eg. 2/4 -> 0; -2/4 -> -1 - // no need fix for !is_greater: - // eg. a <= 2/4 -> a <= 0 - // eg. a <= 0/4 -> a <= 0 - // so just fix for not divided and is_greater - // eg. a >= 2/4 -> a >= 0 + 1 - // eg. a >= 0/4 -> a >= 0 - if (is_greater && !divided) { - result += 1; + bool divided = analyzer_.CanProve(result_ % operand == 0); + + result_ = result_ / operand; + + if (!divided) { + // Handle non-divisible case + // NOTE: this accounts for truc div behavior. + bool target_is_non_neg = expr_map_[target_var].can_prove_non_negative(); + + if (is_greater_) { + result_ += 1; + } else { + // NOTE: this is a bit sutble hack. + // + // condition: + // - x * operand <= result + // - operand > 0 + // - x >= 0 + // + // Then it is fine to deduce that x <= result / operand. 
+ // - if result > 0, this division round down + // - if result < 0, (result / operand) rounds up and may violate the constraint + // however, given that x is always non-negative, + // it is fine to have this relaxed bound, given that the user of deduce bound + // will respect the bound of x + // + // TODO(tvm-team): think about a better API to incorporate constraint of x. + // e.g. specify an interval of x and return a bound + // that is in the interval and satisfies the condition. + if (target_is_non_neg && sign_operand == kPositive) { + // do nothing + } else { + result_ -= 1; + } + } } - Visit(left ? op->a : op->b); } - Expr result; - bool is_greater{true}; - bool success{true}; + Expr result_; + bool is_greater_{true}; + bool success_{true}; private: void Init(); @@ -180,6 +202,8 @@ class BoundDeducer: public IRVisitor { ExprIntSetMap expr_map_; std::vector path_; size_t iter_{0}; + // internal analzyer + Analyzer analyzer_; }; class BoundDeduceInputChecker: public IRVisitor { @@ -202,7 +226,7 @@ class BoundDeduceInputChecker: public IRVisitor { void BoundDeducer::Init() { BoundDeduceInputChecker checker; - if (!checker.Check(this)) success = false; + if (!checker.Check(this)) success_ = false; Transform(); } @@ -211,66 +235,65 @@ void BoundDeducer::Transform() { if (const LT* op = expr_.as()) { if (GetPath(target_, op->a).empty()) { // a < b -> b >= a + 1 - is_greater = true; + is_greater_ = true; expr_ = op->b; - result = op->a + 1; + result_ = op->a + 1; } else { // a < b -> a <= b - 1 - is_greater = false; + is_greater_ = false; expr_ = op->a; - result = op->b - 1; + result_ = op->b - 1; } } else if (const LE* op = expr_.as()) { if (GetPath(target_, op->a).empty()) { // a <= b -> b >= a - is_greater = true; + is_greater_ = true; expr_ = op->b; - result = op->a; + result_ = op->a; } else { - is_greater = false; + is_greater_ = false; expr_ = op->a; - result = op->b; + result_ = op->b; } } else if (const GT* op = expr_.as()) { if (GetPath(target_, op->a).empty()) { // a > b -> b <= a - 1 - is_greater = false; + is_greater_ = false; expr_ = op->b; - result = op->a - 1; + result_ = op->a - 1; } else { // a > b -> a >= b + 1 - is_greater = true; + is_greater_ = true; expr_ = op->a; - result = op->b + 1; + result_ = op->b + 1; } } else if (const GE* op = expr_.as()) { if (GetPath(target_, op->a).empty()) { // a >= b -> b <= a - is_greater = false; + is_greater_ = false; expr_ = op->b; - result = op->a; + result_ = op->a; } else { - is_greater = true; + is_greater_ = true; expr_ = op->a; - result = op->b; + result_ = op->b; } } else { - success = false; + success_ = false; } } void BoundDeducer::Deduce() { Init(); - if (!success) return; + if (!success_) return; Relax(); - if (!success) return; + if (!success_) return; // get the path path_ = GetPath(target_, expr_); if (!path_.size()) { - success = false; + success_ = false; return; } - expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_); Visit(expr_); @@ -278,13 +301,13 @@ void BoundDeducer::Deduce() { void BoundDeducer::Relax() { IntSet a = EvalSet(expr_, relax_map_); - IntSet b = EvalSet(result, relax_map_); + IntSet b = EvalSet(result_, relax_map_); if (a.is_everything() || b.is_everything()) { - success = false; + success_ = false; return; } - expr_ = is_greater ? a.min() : a.max(); - result = is_greater ? b.max() : b.min(); + expr_ = is_greater_ ? a.min() : a.max(); + result_ = is_greater_ ? 
b.max() : b.min(); } IntSet DeduceBound(Expr v, Expr e, @@ -292,12 +315,12 @@ IntSet DeduceBound(Expr v, Expr e, const std::unordered_map& relax_map) { BoundDeducer d(v, e, hint_map, relax_map); d.Deduce(); - if (!d.success) return IntSet::nothing(); + if (!d.success_) return IntSet::nothing(); Expr min = neg_inf(), max = pos_inf(); - if (d.is_greater) { - min = d.result; + if (d.is_greater_) { + min = d.result_; } else { - max = d.result; + max = d.result_; } return IntSet::interval(min, max); } diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h index ec50aef5c51ed..dc6b80a31c7bd 100644 --- a/src/arithmetic/const_fold.h +++ b/src/arithmetic/const_fold.h @@ -155,9 +155,10 @@ template<> inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ const Type& rtype = a.type(); - // due to division and mod can have different modes - // only constant fold positive number where rule is fixed. - if (pa && pb && pa->value >= 0 && pb->value > 0) { + if (pa && pb) { + // due to division and mod can have different modes + // NOTE: this will assumes truc div. + CHECK_NE(pb->value, 0) << "Divide by zero"; return IntImm::make(rtype, pa->value / pb->value); } if (pa) { diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index bc8666e893b4e..6cc829d07e887 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -155,7 +155,6 @@ Mutate_(const Add* op, const Expr& self) { TVM_TRY_REWRITE(max(x, y - z) + z, max(x + z, y)); TVM_TRY_REWRITE(max(x - z, y) + z, max(x, y + z)); - TVM_TRY_REWRITE_IF(min(x, y + z * c1) + z * c2, min(x + z * c2, y), c1.Eval()->value == -c2.Eval()->value); TVM_TRY_REWRITE_IF(max(x, y + z * c1) + z * c2, max(x + z * c2, y), diff --git a/src/arithmetic/stmt_simplify.cc b/src/arithmetic/stmt_simplify.cc index 01cb96ee1323e..162cb1e5fd164 100644 --- a/src/arithmetic/stmt_simplify.cc +++ b/src/arithmetic/stmt_simplify.cc @@ -28,7 +28,6 @@ #include #include #include -#include "arithmetic/Simplify.h" namespace tvm { namespace arith { @@ -158,42 +157,18 @@ Expr CanonicalSimplify(Expr expr, Map vrange) { return analyzer.canonical_simplify(expr); } -template -T Simplify_(T a, Map vrange) { - using namespace HalideIR::Internal; - Scope rscope; +Expr Simplify(Expr expr, Map vrange) { + arith::Analyzer analyzer; for (auto kv : vrange) { - Range r = kv.second; - rscope.push( - kv.first.get(), - Interval(r->min, - simplify(r->min + r->extent - make_const(r->min.type(), 1)))); - } - return HalideIR::Internal::simplify(a, true, rscope); -} - - -Expr Simplify(Expr a, Map vrange) { - // Simplify top level reduce. 
- if (const Reduce* r = a.as()) { - Array new_source; - for (auto& e : r->source) { - new_source.push_back(Simplify_(e, vrange)); - } - Expr new_condition = Simplify_(r->condition, vrange); - if (r->source.same_as(new_source) && - r->condition.same_as(new_condition)) { - return a; - } else { - return Reduce::make( - r->combiner, new_source, r->axis, new_condition, r->value_index); - } + analyzer.Bind(kv.first, kv.second); } - return Simplify_(a, vrange); + expr = analyzer.Simplify(expr); + return expr; } -Stmt Simplify(Stmt a, Map vrange) { - return Simplify_(a, vrange); +Stmt Simplify(Stmt stmt, Map vrange) { + return arith::CanonicalStmtSimplifier().CanonicalSimplify( + stmt, vrange); } } // namespace ir } // namespace tvm diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc index 8c584c50b3c67..3e0615162a8f8 100644 --- a/src/lang/buffer.cc +++ b/src/lang/buffer.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -26,6 +26,7 @@ #include #include #include +#include #include "../arithmetic/compute_expr.h" namespace tvm { diff --git a/src/op/scan_op.cc b/src/op/scan_op.cc index 42b1331e3736c..78f8c82d97dbf 100644 --- a/src/op/scan_op.cc +++ b/src/op/scan_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -80,7 +80,7 @@ Operation ScanOpNode::make(std::string name, for (size_t i = 0; i < init.size(); ++i) { CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype); CHECK_EQ(init[i]->dtype, update[i]->dtype); - CHECK(can_prove(init[i]->shape[0] == axis->dom->min)) + CHECK(prove_equal(init[i]->shape[0], axis->dom->min)) << "init.shape[0] need to match scan_axis.dom.min"; CHECK(prove_equal( state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent)) diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index 0a5b7410f3cff..33dbaed83b697 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -466,8 +466,13 @@ Stmt LoopPartitioner::TryPartition(const Node* node, Stmt body, bool partition_thread_scope) { using namespace arith; + // include hint of var. 
+ hint_map_.insert({var.get(), IntSet::interval(min, max)}); + PartitionFinder finder(var, hint_map_, relax_map_); finder.Visit(body); + + hint_map_.erase(var.get()); if (finder.partitions.empty()) return Stmt(); arith::IntervalSet for_interval(min, max); @@ -504,9 +509,9 @@ Stmt LoopPartitioner::TryPartition(const Node* node, bool pre_stmt_recurse = true; if (middle_interval_i->HasLowerBound()) { body_begin = ir::Simplify(middle_interval.min()); - if (!can_prove(body_begin == min)) { + if (!analyzer_.CanProve(body_begin == min)) { Expr cond = (body_begin - min >= 0); - if (!can_prove(cond)) { + if (!analyzer_.CanProve(cond)) { LOG(WARNING) << "Cannot prove: " << cond << ", when generating the pre doubt loop"; body_begin = Max::make(body_begin, min); @@ -529,10 +534,10 @@ Stmt LoopPartitioner::TryPartition(const Node* node, bool post_stmt_recurse = true; if (middle_interval_i->HasUpperBound()) { post_doubt_begin = ir::Simplify(middle_interval.max() + 1); - if (!can_prove(middle_interval.max() == max)) { + if (!analyzer_.CanProve(middle_interval.max() == max)) { // require the extent to be non-negative Expr cond = (max - post_doubt_begin + 1 >= 0); - if (!can_prove(cond)) { + if (!analyzer_.CanProve(cond)) { LOG(WARNING) << "Cannot prove: " << cond << ", when generating the post doubt loop"; post_doubt_begin = Min::make(post_doubt_begin, max); @@ -554,7 +559,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, // Generating code for middle subrange if (!partition_thread_scope) { Stmt mid_stmt; - if (!can_prove(body_begin >= post_doubt_begin)) { + if (!analyzer_.CanProve(body_begin >= post_doubt_begin)) { // [body_begin, post_doubt_begin) Stmt simplified_body = ConditionEliminator(cond_set, cond_value).Mutate(body); Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}}); @@ -576,8 +581,8 @@ Stmt LoopPartitioner::TryPartition(const Node* node, s = AppendStmts(s, post_stmt); } else { Expr cond = const_true(); - if (!can_prove(body_begin == min)) cond = cond && (var >= body_begin); - if (!can_prove(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin); + if (!analyzer_.CanProve(body_begin == min)) cond = cond && (var >= body_begin); + if (!analyzer_.CanProve(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin); s = ThreadPartitionInserter(cond_set, cond).Mutate(stmt); } s = ConvertSSA(s); @@ -587,7 +592,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, inline Stmt LoopPartitioner::MakeFor(const Node *node, Expr extent, Stmt body) { const For *for_node = static_cast(node); CHECK(for_node); - if (can_prove(extent == make_const(Int(32), 1))) { + if (analyzer_.CanProve(extent == make_const(Int(32), 1))) { // If the loop extent is 1, do not create the loop anymore return Substitute(body, {{Var{for_node->loop_var}, make_const(Int(32), 0)}}); } else { diff --git a/src/pass/narrow_channel_access.cc b/src/pass/narrow_channel_access.cc index 731064edb0121..57f3baf20e108 100644 --- a/src/pass/narrow_channel_access.cc +++ b/src/pass/narrow_channel_access.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -200,7 +200,7 @@ class ChannelAccessRewriter : public IRMutator { Expr base = linear_eq[1]; if (!is_zero(base)) return body; Expr left = ir::Simplify(adv_op->value - coeff * for_op->extent); - if (!can_prove(left >= 0)) return body; + if (!analyzer_.CanProve(left >= 0)) return body; // rewrite access index. ChannelAccessIndexRewriter rw( ch->handle_var.get(), var * coeff, read_access); @@ -233,6 +233,7 @@ class ChannelAccessRewriter : public IRMutator { return body; } + arith::Analyzer analyzer_; std::vector tasks_; }; diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 806a80ad4dc90..eba1cee8b7c70 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -606,7 +606,7 @@ class StoragePlanRewriter : public IRMutator { } // transform to alloc bytes auto type_bits = alloc_type.bits() * alloc_type.lanes(); - bool divided = can_prove(combo_size % type_bits == 0); + bool divided = analyzer_.CanProve(combo_size % type_bits == 0); combo_size = combo_size / type_bits; // round up for can not divided if (!divided) { @@ -920,6 +920,8 @@ class StoragePlanRewriter : public IRMutator { std::unordered_map alloc_map_; // The allocations std::vector > alloc_vec_; + // analyzer + arith::Analyzer analyzer_; }; // Turn alloc into vector alloc diff --git a/src/pass/vectorize_loop.cc b/src/pass/vectorize_loop.cc index 8c3d383c1529a..a48e8b4d7e83d 100644 --- a/src/pass/vectorize_loop.cc +++ b/src/pass/vectorize_loop.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -132,11 +133,11 @@ class Vectorizer : public IRMutator { if (lanes != 1) { const Ramp* b_ramp = b.as(); const Ramp* a_ramp = a.as(); - if (a_ramp && b.type().lanes() == 1 && can_prove(b > 0)) { + if (a_ramp && b.type().lanes() == 1 && analyzer_.CanProve(b > 0)) { return Ramp::make( a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes); } - if (b_ramp && a.type().lanes() == 1 && can_prove(a > 0)) { + if (b_ramp && a.type().lanes() == 1 && analyzer_.CanProve(a > 0)) { return Ramp::make( b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes); } @@ -186,7 +187,7 @@ class Vectorizer : public IRMutator { Expr stride = this->Mutate(op->stride); if (base.type().lanes() > 1 && stride.type().lanes() == 1) { const Ramp* base_ramp = base.as(); - if (can_prove(base_ramp->stride == stride * make_const(stride.type(), op->lanes))) { + if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.type(), op->lanes))) { return Ramp::make(base_ramp->base, stride, op->lanes * base_ramp->lanes); } } @@ -423,6 +424,8 @@ class Vectorizer : public IRMutator { } private: + // analyzer + arith::Analyzer analyzer_; // variable to be replaced Var var_; // the lanes. 
diff --git a/src/schedule/message_passing.cc b/src/schedule/message_passing.cc index a7f974613aa15..0dc82abd9a8f5 100644 --- a/src/schedule/message_passing.cc +++ b/src/schedule/message_passing.cc @@ -432,9 +432,9 @@ void PassDownBitMaskOr(const Stage& stage, */ void PassUpBoundCheck(const Stage& s, const Map& dom_map, - std::unordered_map* p_state) { + std::unordered_map* p_state, + arith::Analyzer* analyzer) { auto& state = *p_state; - using HalideIR::Internal::can_prove; for (size_t i = s->relations.size(); i != 0; --i) { IterVarRelation rel = s->relations[i - 1]; if (const SplitNode* s = rel.as()) { @@ -447,7 +447,7 @@ void PassUpBoundCheck(const Stage& s, if (outer || inner) { state[s->parent] = true; } else { - if (can_prove(dom_map.at(s->parent)->extent == factor * step)) { + if (analyzer->CanProve(dom_map.at(s->parent)->extent == factor * step)) { state[s->parent] = false; } else { state[s->parent] = true; @@ -476,11 +476,13 @@ std::vector MakeBoundCheck( const std::unordered_map& value_map, bool skip_ivar_domain, const std::unordered_set& skip_iter) { + Analyzer analyzer; + std::unordered_map bound_state; for (IterVar iv : stage->leaf_iter_vars) { bound_state[iv] = false; } - PassUpBoundCheck(stage, dom_map, &bound_state); + PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer); std::vector preds; std::unordered_map iset_dmap; @@ -496,7 +498,7 @@ std::vector MakeBoundCheck( Range dom = dom_map.at(iv); Expr value = ComputeExpr(value_map.at(iv), dom->min); Expr vmax = EvalSet(value, iset_dmap).max(); - if (vmax.type() != value.type() || !can_prove(vmax < dom->extent)) { + if (vmax.type() != value.type() || !analyzer.CanProve(vmax < dom->extent)) { preds.emplace_back(value < dom->extent); } } @@ -511,10 +513,10 @@ std::vector MakeBoundCheck( Expr vmin = s.min(); Expr vmax = s.max(); // The range of `value` resides in [vmin, vmax] - if (vmin.type() != value.type() || !can_prove(vmin >= 0)) { + if (vmin.type() != value.type() || !analyzer.CanProve(vmin >= 0)) { preds.emplace_back(value >= 0); } - if (vmax.type() != value.type() || !can_prove(vmax < iv->dom->extent)) { + if (vmax.type() != value.type() || !analyzer.CanProve(vmax < iv->dom->extent)) { preds.emplace_back(value < iv->dom->extent); } } diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index c5f1b1656dd5f..760ed0f233f7e 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -740,7 +740,7 @@ Array Schedule::rfactor(const Tensor& tensor, const Reduce* reduce = compute_op->body[idx].as(); CHECK(reduce) << "Can only rfactor non-inline reductions"; predicates.push_back(reduce->condition); - Expr predicate = likely(simplify(arith::ComputeReduce(predicates, Expr()))); + Expr predicate = likely(arith::ComputeReduce(predicates, Expr())); std::unordered_map vsub; diff --git a/tests/cpp/ir_simplify_test.cc b/tests/cpp/ir_simplify_test.cc index 35968f8524de6..5a5dc03f0165b 100644 --- a/tests/cpp/ir_simplify_test.cc +++ b/tests/cpp/ir_simplify_test.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -21,12 +21,6 @@ #include #include #include -#include - -TEST(IRSIMPLIFY, Basic) { - using namespace HalideIR::Internal; - simplify_test(); -} TEST(IRSIMPLIFY, MinMax) { auto x = tvm::var("x"); diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py index 7fe6f56edea78..d26b508ff262a 100644 --- a/tests/python/unittest/test_arith_deduce_bound.py +++ b/tests/python/unittest/test_arith_deduce_bound.py @@ -16,6 +16,14 @@ # under the License. import tvm + +def assert_expr_equal(a, b): + res = tvm.ir_pass.Simplify(a - b) + equal = isinstance(res, tvm.expr.IntImm) and res.value == 0 + if not equal: + raise ValueError("{} and {} are not equal".format(a, b)) + + def test_deduce(): a = tvm.var('a') b = tvm.var('b') @@ -29,31 +37,34 @@ def test_deduce(): e0 = (-b)*a+c-d res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) - ans0 = ((d - c) /(b*-1)) - assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0) + ans0 = ((d - c) /(b*-1) + (-1)) + assert_expr_equal(res0.max_value, ans0) # expression containing variable a is on rhs res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) - assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0) + assert_expr_equal(res0.max_value, ans0) e0 = d*a+c-d res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) - ans0 = ((0-c)/d + 1) - assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0) + ans0 = ((d-c)/d - 1) + assert_expr_equal(res0.max_value, ans0) # expression containing variable a is on rhs res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) - assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0) + assert_expr_equal(res0.max_value, ans0) + e1 = (a*4+b < c) res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) - ans1 = (((c - b) + -1)/4) - assert str(tvm.ir_pass.Simplify(res1.max_value)) == str(ans1) + ans1 = (((c - b) + -1)/4 -1) + assert_expr_equal(res1.max_value, ans1) + # expression containing variable a is on rhs e1 = (c > a*4+b) res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) - assert str(tvm.ir_pass.Simplify(res1.max_value)) == str(ans1) + assert_expr_equal(res1.max_value, ans1) + e2 = (tvm.max(5, a * 4) < 0) res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) @@ -66,7 +77,6 @@ def test_deduce(): assert str(res2.max_value) == "neg_inf" assert str(res2.min_value) == "pos_inf" - e3 = (-b)+a*c-d res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) ans3 = 2/c+1 @@ -75,6 +85,7 @@ def test_deduce(): res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3) + def test_check(): a = tvm.var('a') b = tvm.var('b') diff --git a/tests/python/unittest/test_pass_basic.py b/tests/python/unittest/test_pass_basic.py index fc76c306731c1..b05d75ab2d1ef 100644 --- a/tests/python/unittest/test_pass_basic.py +++ b/tests/python/unittest/test_pass_basic.py @@ -24,9 +24,6 @@ def test_simplify(): assert(tvm.ir_pass.Equal(e2, x * 8)) e3 = tvm.ir_pass.Simplify(x - x / 3 * 3) assert(tvm.ir_pass.Equal(e3, tvm.make.Mod(x, 3))) - let = tvm.make.Let(x, 1, x + 3) - e4 = tvm.ir_pass.Simplify(let) - 
assert(tvm.ir_pass.Equal(e4, 4)) def test_verify_ssa(): From 9e1fcc863bb5e6fb75b3fb1ad054ca3ca51aad99 Mon Sep 17 00:00:00 2001 From: Yida Wang Date: Mon, 1 Jul 2019 15:09:50 -0700 Subject: [PATCH 03/26] [ANALYSIS] Mac count deconv (#3469) * add mac count for conv 2d transpose * add the explanation of missing parameter in docstring * typo * fix pylint --- python/tvm/relay/op/nn/nn.py | 6 +++ src/relay/pass/mac_count.cc | 45 +++++++++++++++++++++-- tests/python/relay/test_pass_mac_count.py | 31 +++++++++++++++- 3 files changed, 77 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 7bce9dd3c5b99..1de86173040d0 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -137,6 +137,12 @@ def conv2d_transpose(data, dilation : Tuple[int], optional Specifies the dilation rate to be used for dilated convolution. + channels : int, optional + Number of output channels of this convolution. + + kernel_size : tuple of int, optional + The spatial of the convolution kernel. + groups : int, optional Number of groups for grouped convolution. diff --git a/src/relay/pass/mac_count.cc b/src/relay/pass/mac_count.cc index 3d77fabe6fe91..ce70eb0512149 100644 --- a/src/relay/pass/mac_count.cc +++ b/src/relay/pass/mac_count.cc @@ -88,11 +88,44 @@ int64_t ConvMacCount(const Call& call_node) { << "The dimension of the output tensor in Conv 2D should be 4 or 5."; int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size); CHECK_EQ(input_channel % conv_2d_attr->groups, 0) - << "The number of input channels is not divisble by groups."; + << "The number of input channels is not divisble by groups."; count *= input_channel/conv_2d_attr->groups; return count; } +int64_t Conv2dTransposeMacCount(const Call& call_node) { + if (!call_node->checked_type_.defined()) { + LOG(WARNING) << "The infer type pass should be called before the mac count pass"; + return 0; + } + Array args = call_node->args; + CHECK(args.size() == 2) + << "The number of input arguments of a CONV 2D Transpose node should be 2."; + const auto* conv_2d_transpose_attr = call_node->attrs.as(); + const auto* data_type = args[0]->checked_type().as(); + Array data_shape = data_type->shape; + std::string data_layout = conv_2d_transpose_attr->data_layout; + int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C')); + int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c')); + CHECK(C_ind != -1) + << "There is no input channel dimension."; + int64_t input_channel = static_cast(data_shape[C_ind].as()->value); + if (c_ind != -1) + input_channel *= static_cast(data_shape[c_ind].as()->value); + Array kernel_size = conv_2d_transpose_attr->kernel_size; + CHECK(kernel_size.size() == 2) + << "The dimension of the kernel in Conv 2D Transpose should be 2."; + const auto* expr = call_node->checked_type().as(); + Array output_tensor = expr->shape; + CHECK(output_tensor.size() == 4 || output_tensor.size() == 5) + << "The dimension of the output tensor in Conv 2D Transpose should be 4 or 5."; + int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size); + CHECK_EQ(input_channel % conv_2d_transpose_attr->groups, 0) + << "The number of input channels is not divisble by groups."; + count *= input_channel/conv_2d_transpose_attr->groups; + return count; +} + int64_t DenseMacCount(const Call& call_node) { if (!call_node->checked_type_.defined()) { LOG(WARNING) << "The infer type pass should be called before the mac count pass"; @@ -106,13 +139,13 @@ 
int64_t DenseMacCount(const Call& call_node) { Array data_shape = data_type->shape; Array weight_shape = weight_type->shape; CHECK(data_shape.size() == 2 && weight_shape.size() == 2) - << "The dimension of an input tensor to Dense node should be 2."; + << "The dimension of an input tensor to Dense node should be 2."; int64_t d1 = static_cast(data_shape[0].as()->value); int64_t d2 = static_cast(data_shape[1].as()->value); int64_t d3 = static_cast(weight_shape[0].as()->value); int64_t d4 = static_cast(weight_shape[1].as()->value); CHECK(d2 == d4) - << "The dimensions of input arguments do not match."; + << "The dimensions of input arguments do not match."; int64_t count = d1 * d2 * d3; return count; } @@ -120,6 +153,9 @@ int64_t DenseMacCount(const Call& call_node) { RELAY_REGISTER_OP("nn.conv2d") .set_attr("FMacCount", ConvMacCount); +RELAY_REGISTER_OP("nn.conv2d_transpose") +.set_attr("FMacCount", Conv2dTransposeMacCount); + RELAY_REGISTER_OP("nn.dense") .set_attr("FMacCount", DenseMacCount); @@ -129,7 +165,8 @@ class MacCounter : private ExprVisitor { count_ = 0; } static int64_t GetTotalMacNumber(const Expr& expr) { - LOG(INFO) << "This pass only counts MACs in direct CONV 2D and Dense ops"; + LOG(INFO) << "This pass only counts MACs in direct CONV 2D, " + << "CONV 2D Transpose and Dense ops"; MacCounter counter; counter(expr); return counter.count_; diff --git a/tests/python/relay/test_pass_mac_count.py b/tests/python/relay/test_pass_mac_count.py index 98ba1ad6325d9..a7739a6444733 100644 --- a/tests/python/relay/test_pass_mac_count.py +++ b/tests/python/relay/test_pass_mac_count.py @@ -55,7 +55,7 @@ def test_conv(): weight, channels=output_channel, kernel_size=(kh, kw), - padding=(1, 1)) + padding=(h_padding, w_padding)) func = relay.Function([data, weight], relay.Tuple(tvm.convert([conv2d]))) func = relay.ir_pass.infer_type(func) @@ -127,8 +127,37 @@ def test_depthwise_conv2d(): compute_count = relay.ir_pass.get_total_mac_number(func) assert compute_count == 2 * np.prod(dshape) * 3*3 +def test_conv_2d_transpose(): + batch_size = 1 + input_channel = 3 + h = 224 + w = 224 + output_channel = 64 + kh = 7 + kw = 7 + h_padding = 1 + w_padding = 1 + oh = h - h_padding * 2 + kh - 1 + ow = w - w_padding * 2 + kw - 1 + dshape = (batch_size, input_channel, h, w) + weight = relay.var("weight", shape=(input_channel, output_channel, kh, kw)) + data = relay.var("data", shape=dshape) + conv2d_transpose = relay.nn.conv2d_transpose( + data, + weight, + channels=output_channel, + kernel_size=(kh, kw), + padding=(h_padding, w_padding)) + func = relay.Function([data, weight], + relay.Tuple(tvm.convert([conv2d_transpose]))) + func = relay.ir_pass.infer_type(func) + compute_count = relay.ir_pass.get_total_mac_number(func) + expect_count = batch_size * input_channel * oh * ow * output_channel * kh * kw + assert compute_count == expect_count + if __name__ == "__main__": test_conv() test_gemm() test_simple_network() test_depthwise_conv2d() + test_conv_2d_transpose() From dfc1fb251b4845e06b8b45de2763913c16512cbf Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Mon, 1 Jul 2019 15:53:32 -0700 Subject: [PATCH 04/26] [RUNTIME] Only checks the custom data type if it is bigger than the specified range (#3471) --- include/tvm/runtime/packed_func.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 82b3dd4695415..17fd626ee51d0 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ 
-962,10 +962,10 @@ inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*) if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) { os << "bool"; return os; } - if (GetCustomTypeRegistered(t.code)) { - os << "custom[" << GetCustomTypeName(t.code) << "]"; - } else { + if (t.code < kCustomBegin) { os << TypeCode2Str(t.code); + } else { + os << "custom[" << GetCustomTypeName(t.code) << "]"; } if (t.code == kHandle) return os; os << static_cast(t.bits); @@ -987,10 +987,10 @@ inline std::string TVMType2String(TVMType t) { if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) { return "bool"; } - if (GetCustomTypeRegistered(t.code)) { - repr += "custom[" + GetCustomTypeName(t.code) + "]"; - } else { + if (t.code < kCustomBegin) { repr += TypeCode2Str(t.code); + } else { + repr += "custom[" + GetCustomTypeName(t.code) + "]"; } if (t.code == kHandle) return repr; repr += std::to_string(static_cast(t.bits)); From 77445311540c0dfa7b124304b5cf89da6f2c210f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Mon, 1 Jul 2019 19:29:56 -0700 Subject: [PATCH 05/26] [Relay] fix 'please use input parameter mod warning' triggered in build_module (#3452) --- .../graph_tuner/utils/traverse_graph.py | 4 +- python/tvm/autotvm/task/relay_integration.py | 6 ++- python/tvm/relay/backend/interpreter.py | 2 + python/tvm/relay/build_module.py | 15 ++++--- python/tvm/relay/module.py | 4 +- .../relay/test_autotvm_task_extraction.py | 4 +- .../relay/test_backend_compile_engine.py | 6 +-- .../relay/test_backend_graph_runtime.py | 4 +- tests/python/relay/test_cpp_build_module.py | 4 +- tests/python/relay/test_pass_fuse_ops.py | 45 ++++++++----------- 10 files changed, 45 insertions(+), 49 deletions(-) diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py index c0debaedede0d..7e7f1749eae73 100644 --- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -127,10 +127,10 @@ def _traverse_expr(node): free_var = relay.Var("var_%d" % i, input_type) params.append(free_var) call = relay.Call(node.op, params, node.attrs) - func = relay.Function(params, call) + mod = relay.Module.from_expr(relay.Function(params, call)) relay.backend.compile_engine.get().clear() build_thread = threading.Thread(target=relay.build, - args=(func, + args=(mod, "llvm -device=tracing", None, None)) diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index d80443a208d66..5b0294ef2d07d 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -105,8 +105,9 @@ def extract_from_program(func, params, ops, target, target_host=None): relay.backend.compile_engine.get().clear() # wrap build call in thread to avoid multiprocessing problems + mod = relay.Module.from_expr(func) build_thread = threading.Thread(target=_build, - args=(func, + args=(mod, target, target_host, params)) @@ -183,8 +184,9 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None): for func, param in zip(funcs, params): relay.backend.compile_engine.get().clear() # wrap build call in thread to avoid multiprocessing problems + mod = relay.Module.from_expr(func) build_thread = threading.Thread(target=my_build, - args=(func, + args=(mod, target, target_host, params)) diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index 
c54a65b78fb23..cf643f61243cc 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -163,6 +163,8 @@ def _convert_args(self, expr, args, kwargs): args: List[tvm.NDArray] The new arguments with all keyword arguments placed in the correct slot. """ + assert expr is not None + if not kwargs: return args diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index cdda17aa517b6..6337e629516c3 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -25,7 +25,6 @@ from .. import nd as _nd, target as _target, autotvm from ..contrib import graph_runtime as _graph_rt from . import _build_module -from . import ir_pass from . import ty as _ty from . import expr as _expr from .module import Module as _Module @@ -227,23 +226,23 @@ class GraphExecutor(_interpreter.Executor): """ def __init__(self, mod, ctx, target): + assert mod is not None self.mod = mod self.ctx = ctx self.target = target def _make_executor(self, expr=None): - if not expr: - assert self.mod, "either expr or self.mod should be not null." - expr = self.mod[self.mod.entry_func] - ret_type = ir_pass.infer_type(expr).ret_type + if expr: + self.mod[self.mod.entry_func] = expr + ret_type = self.mod[self.mod.entry_func].checked_type.ret_type num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1 - graph_json, mod, params = build(expr, target=self.target) + graph_json, mod, params = build(self.mod, target=self.target) gmodule = _graph_rt.create(graph_json, mod, self.ctx) if params: gmodule.set_input(**params) def _graph_wrapper(*args, **kwargs): - args = self._convert_args(expr, args, kwargs) + args = self._convert_args(self.mod[self.mod.entry_func], args, kwargs) # Create map of inputs. for i, arg in enumerate(args): gmodule.set_input(i, arg) @@ -280,6 +279,8 @@ def create_executor(kind="debug", target : :py:class:`tvm.Target` The corresponding context """ + if mod is None: + mod = _Module() if ctx is not None: assert ctx.device_type == _nd.context(str(target), 0).device_type else: diff --git a/python/tvm/relay/module.py b/python/tvm/relay/module.py index 138dfa8822154..1a5e82269a963 100644 --- a/python/tvm/relay/module.py +++ b/python/tvm/relay/module.py @@ -33,7 +33,7 @@ class Module(RelayNode): Parameters ---------- - functions : dict, optional. + functions: Optional[dict]. Map of global var to Function """ def __init__(self, functions=None, type_definitions=None): @@ -100,7 +100,7 @@ def __getitem__(self, var): Parameters ---------- - var: str or GlobalVar + var: Union[String, GlobalVar, GlobalTypeVar] The name or global variable. 
Returns diff --git a/tests/python/relay/test_autotvm_task_extraction.py b/tests/python/relay/test_autotvm_task_extraction.py index 07116cd5faf50..7374ab94dde40 100644 --- a/tests/python/relay/test_autotvm_task_extraction.py +++ b/tests/python/relay/test_autotvm_task_extraction.py @@ -40,8 +40,8 @@ def test_task_extraction(): net, params, input_shape = get_network('resnet-18', batch_size=1) tasks = autotvm.task.extract_from_program(net, target=target, - params=params, - ops=(relay.op.nn.conv2d,)) + params=params, + ops=(relay.op.nn.conv2d,)) assert len(tasks) == 12 net, params, input_shape = get_network('resnet-18', batch_size=1) diff --git a/tests/python/relay/test_backend_compile_engine.py b/tests/python/relay/test_backend_compile_engine.py index ca4619c978860..f493a9b3f537d 100644 --- a/tests/python/relay/test_backend_compile_engine.py +++ b/tests/python/relay/test_backend_compile_engine.py @@ -57,7 +57,7 @@ def test_compile_placeholder_bypass(): result = relay.Tuple([x, relay.op.concatenate([y, z], axis=0)]) func = relay.Function(relay.ir_pass.free_vars(result), result) with relay.build_config(opt_level=0): - graph, lib, params = relay.build(func, 'llvm') + graph, lib, params = relay.build(relay.Module.from_expr(func), 'llvm') def test_compile_injective_with_tuple(): @@ -66,7 +66,7 @@ def test_compile_injective_with_tuple(): x_transpose = relay.transpose(x) output = relay.Tuple([x_transpose, y]) func = relay.Function([x, y], output) - relay.build(func, 'llvm') + relay.build(relay.Module.from_expr(func), 'llvm') def test_compile_tuple_dup(): @@ -74,7 +74,7 @@ def test_compile_tuple_dup(): log = relay.log(x) output = relay.Tuple([log, log]) f = relay.Function([x], output) - relay.build(f, 'llvm') + relay.build(relay.Module.from_expr(f), 'llvm') if __name__ == "__main__": diff --git a/tests/python/relay/test_backend_graph_runtime.py b/tests/python/relay/test_backend_graph_runtime.py index 32687a4e80130..18e01e39ea276 100644 --- a/tests/python/relay/test_backend_graph_runtime.py +++ b/tests/python/relay/test_backend_graph_runtime.py @@ -101,7 +101,7 @@ def test_with_params(): x_data = np.random.rand(10, 5).astype('float32') y_data = np.random.rand(1, 5).astype('float32') params = {"y": y_data} - graph, lib, params = relay.build(func, "llvm", params=params) + graph, lib, params = relay.build(relay.Module.from_expr(func), "llvm", params=params) mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) mod.set_input(**params) mod.set_input(x=x_data) @@ -170,7 +170,7 @@ def unit_numpy(X, W): for target, ctx in ctx_list(): with relay.build_config(opt_level=2): - graph, lib, params = relay.build(z, target) + graph, lib, params = relay.build(relay.Module.from_expr(z), target) m = graph_runtime.create(graph, lib, ctx) m.set_input("X", tvm.nd.array(x.astype(dtype))) m.set_input("y", tvm.nd.array(y.astype(dtype))) diff --git a/tests/python/relay/test_cpp_build_module.py b/tests/python/relay/test_cpp_build_module.py index affc6ce04c6b1..e2db81dab1c9f 100644 --- a/tests/python/relay/test_cpp_build_module.py +++ b/tests/python/relay/test_cpp_build_module.py @@ -43,7 +43,7 @@ def test_basic_build(): targets = { tvm.expr.IntImm("int32", ctx.device_type): tgt } - g_json, mmod, params = relay.build(func, targets, "llvm", params=params) + g_json, mmod, params = relay.build(relay.Module.from_expr(func), targets, "llvm", params=params) # test rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx) @@ -115,7 +115,7 @@ def check_conversion(tgt, ctx): # build with relay.build_config(opt_level=1): - g_json, mmod, 
params = relay.build(func, tgt) + g_json, mmod, params = relay.build(relay.Module.from_expr(func), tgt) # test rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx) diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 6d6781046a10d..0ecbfe6b4d4ac 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -342,6 +342,9 @@ def expected(dim): assert relay.ir_pass.alpha_equal(zz, after) +fuse0 = relay.transform.FuseOps(fuse_opt_level=0) +fuse2 = relay.transform.FuseOps(fuse_opt_level=2) + def test_tuple_intermediate(): def before(x): inj = relay.squeeze(x) @@ -363,16 +366,12 @@ def expected(p0): dshape = (1, 16, 64, 64) x = relay.var("x", shape=dshape) - z = before(x) - z = relay.ir_pass.infer_type(z) - zz = relay.ir_pass.fuse_ops(z, opt_level=0) - assert not relay.ir_pass.free_vars(zz) - zz = relay.ir_pass.fuse_ops(z, opt_level=2) - relay.build(zz, 'llvm') - zz = relay.ir_pass.infer_type(zz) - assert not relay.ir_pass.free_vars(zz) + orig = before(x) + fuse0(relay.Module.from_expr(orig)) + m = fuse2(relay.Module.from_expr(orig)) + relay.build(m, 'llvm') after = relay.ir_pass.infer_type(expected(x)) - assert relay.ir_pass.alpha_equal(zz, after) + assert relay.ir_pass.alpha_equal(m[m.entry_func], after) def test_tuple_consecutive(): @@ -422,16 +421,12 @@ def expected(dshape): dshape = (1, 16, 64, 64) x = relay.var("x", shape=dshape) - z = before(x) - z = relay.ir_pass.infer_type(z) - zz = relay.ir_pass.fuse_ops(z, opt_level=0) - assert not relay.ir_pass.free_vars(zz) - zz = relay.ir_pass.fuse_ops(z, opt_level=2) - relay.build(zz, 'llvm') - zz = relay.ir_pass.infer_type(zz) - assert not relay.ir_pass.free_vars(zz) + orig = before(x) + fuse0(relay.Module.from_expr(orig)) + m = fuse2(relay.Module.from_expr(orig)) + relay.build(m, 'llvm') after = relay.ir_pass.infer_type(expected(dshape)) - assert relay.ir_pass.alpha_equal(zz, after) + assert relay.ir_pass.alpha_equal(m[m.entry_func], after) def test_inception_like(): @@ -493,16 +488,12 @@ def expected(dshape): return relay.Function(relay.ir_pass.free_vars(out), out) dshape = (1, 16, 64, 64) - z = before(dshape) - z = relay.ir_pass.infer_type(z) - zz = relay.ir_pass.fuse_ops(z, opt_level=0) - assert not relay.ir_pass.free_vars(zz) - zz = relay.ir_pass.fuse_ops(z, opt_level=2) - relay.build(zz, 'llvm') - zz = relay.ir_pass.infer_type(zz) - assert not relay.ir_pass.free_vars(zz) + orig = before(dshape) + fuse0(relay.Module.from_expr(orig)) + m = fuse2(relay.Module.from_expr(orig)) + relay.build(m, 'llvm') after = relay.ir_pass.infer_type(expected(dshape)) - assert relay.ir_pass.alpha_equal(zz, after) + assert relay.ir_pass.alpha_equal(m[m.entry_func], after) def test_fuse_parallel_injective(): From 264660471193cf7b062dbf945678e0bbd06a5144 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Mon, 1 Jul 2019 21:16:12 -0700 Subject: [PATCH 06/26] [Runtime] Android argsort support (#3472) * Add contrib sort functions to android rpc app. * replaced tab with spaces oops. 
--- apps/android_rpc/app/src/main/jni/Application.mk | 4 ++++ apps/android_rpc/app/src/main/jni/make/config.mk | 3 +++ apps/android_rpc/app/src/main/jni/tvm_runtime.h | 4 ++++ 3 files changed, 11 insertions(+) diff --git a/apps/android_rpc/app/src/main/jni/Application.mk b/apps/android_rpc/app/src/main/jni/Application.mk index aef7629990c2f..548b69160b174 100644 --- a/apps/android_rpc/app/src/main/jni/Application.mk +++ b/apps/android_rpc/app/src/main/jni/Application.mk @@ -23,3 +23,7 @@ ifeq ($(USE_VULKAN), 1) APP_CPPFLAGS += -DTVM_VULKAN_RUNTIME=1 APP_LDFLAGS += -lvulkan endif + +ifeq ($(USE_SORT), 1) + APP_CPPFLAGS += -DUSE_SORT=1 +endif diff --git a/apps/android_rpc/app/src/main/jni/make/config.mk b/apps/android_rpc/app/src/main/jni/make/config.mk index c40ce4ba3ec7d..f61811bd604e4 100644 --- a/apps/android_rpc/app/src/main/jni/make/config.mk +++ b/apps/android_rpc/app/src/main/jni/make/config.mk @@ -22,6 +22,9 @@ USE_OPENCL = 0 # whether to enable Vulkan during compile USE_VULKAN = 0 +# whether to enable contrib sort functions during compile +USE_SORT = 1 + ifeq ($(USE_VULKAN), 1) # Statically linking vulkan requires API Level 24 or higher APP_PLATFORM = android-24 diff --git a/apps/android_rpc/app/src/main/jni/tvm_runtime.h b/apps/android_rpc/app/src/main/jni/tvm_runtime.h index 60b41baaf8e70..aadc4d1884307 100644 --- a/apps/android_rpc/app/src/main/jni/tvm_runtime.h +++ b/apps/android_rpc/app/src/main/jni/tvm_runtime.h @@ -66,6 +66,10 @@ #include "../src/runtime/vulkan/vulkan_module.cc" #endif +#ifdef USE_SORT +#include "../src/contrib/sort/sort.cc" +#endif + #include From 0af5c21614eb4d4c698e520921d0b158b759774f Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Tue, 2 Jul 2019 00:20:26 -0700 Subject: [PATCH 07/26] [Codegen] Support broadcast op with symbolic shape (#3389) * [Codegen] Support broadcast op with symbolic shape * fix case where last dim = 1 * use enum; simplify stride calculation; improve doc * fix lint * improve py doc --- include/tvm/buffer.h | 15 +++++++---- python/tvm/api.py | 33 +++++++++++++++++++++-- src/api/api_lang.cc | 8 +++++- src/codegen/build_module.cc | 2 +- src/lang/buffer.cc | 15 ++++++++--- src/pass/arg_binder.cc | 15 +++++++++++ src/pass/inject_copy_intrin.cc | 4 +-- src/pass/storage_flatten.cc | 2 +- tests/python/unittest/test_lang_buffer.py | 30 +++++++++++++++++++++ topi/include/topi/detail/extern.h | 2 +- 10 files changed, 110 insertions(+), 16 deletions(-) diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h index ed4ac5ea6a63f..1233e9b0b89b8 100644 --- a/include/tvm/buffer.h +++ b/include/tvm/buffer.h @@ -36,10 +36,11 @@ namespace tvm { // Internal node container Buffer class BufferNode; -/*! \brief memory access kind */ -enum class AccessMask : int { - kRead = 1, - kWrite = 2 +/*! \brief buffer type */ +enum BufferType : int { + kDefault = 1, + // Maps buffer[i][j][k] -> buffer[i][0][k] if dimension i's shape equals 1. + kAutoBroadcast = 2, }; /*! @@ -129,6 +130,8 @@ class BufferNode : public Node { * elem_offset is guaranteed to be multiple of offset_factor. */ int offset_factor; + /*! \brief buffer type */ + BufferType buffer_type; /*! \brief constructor */ BufferNode() {} @@ -142,6 +145,7 @@ class BufferNode : public Node { v->Visit("scope", &scope); v->Visit("data_alignment", &data_alignment); v->Visit("offset_factor", &offset_factor); + v->Visit("buffer_type", &buffer_type); } /*! 
\return preferred index type for this buffer node */ @@ -159,7 +163,8 @@ class BufferNode : public Node { std::string name, std::string scope, int data_alignment, - int offset_factor); + int offset_factor, + BufferType buffer_type); static constexpr const char* _type_key = "Buffer"; TVM_DECLARE_NODE_TYPE_INFO(BufferNode, Node); diff --git a/python/tvm/api.py b/python/tvm/api.py index d88f06170543c..e4777b6e39649 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -531,7 +531,8 @@ def decl_buffer(shape, elem_offset=None, scope="", data_alignment=-1, - offset_factor=0): + offset_factor=0, + buffer_type=""): """Declare a new symbolic buffer. Normally buffer is created automatically during lower and build. @@ -574,11 +575,39 @@ def decl_buffer(shape, If 0 is pssed, the alignment will be set to 1. if non-zero is passed, we will created a Var for elem_offset if elem_offset is not None. + buffer_type: str, optional, {"", "auto_broadcast"} + auto_broadcast buffer allows one to implement broadcast computation + without considering whether dimension size equals to one. + TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension i's shape equals 1. + Returns ------- buffer : Buffer The created buffer + Example + ------- + Here's an example of how broadcast buffer can be used to define a symbolic broadcast operation, + + .. code-block:: python + + m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2") + n0, n1, n2 = tvm.var("n0"), tvm.var("n1"), tvm.var("n2") + o0, o1, o2 = tvm.var("o0"), tvm.var("o1"), tvm.var("o2") + A = tvm.placeholder((m0, m1, m2), name='A') + B = tvm.placeholder((n0, n1, n2), name='B') + C = tvm.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C') + Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="broadcast") + Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="broadcast") + s = tvm.create_schedule(C.op) + fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb}) + ctx = tvm.cpu(0) + a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=(2, 1, 3)).astype(B.dtype), ctx) + c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx) + fadd(a, b, c) + tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) + Note ---- Buffer data structure reflects the DLTensor structure in dlpack. @@ -602,7 +631,7 @@ def decl_buffer(shape, data = var(name, "handle") return _api_internal._Buffer( data, dtype, shape, strides, elem_offset, name, scope, - data_alignment, offset_factor) + data_alignment, offset_factor, buffer_type) def layout(layout_str): """Create a layout node from a string. diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 42d60b85e375f..00ac715e8c075 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -207,7 +207,13 @@ TVM_REGISTER_API("Range") }); TVM_REGISTER_API("_Buffer") -.set_body_typed(BufferNode::make); +.set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args.size(), 10); + auto buffer_type = args[9].operator std::string(); + BufferType type = (buffer_type == "auto_broadcast") ? 
kAutoBroadcast : kDefault; + *ret = BufferNode::make(args[0], args[1], args[2], args[3], args[4], + args[5], args[6], args[7], args[8], type); + }); TVM_REGISTER_API("_BufferAccessPtr") .set_body_method(&Buffer::access_ptr); diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 6917200ff9205..c1622338174df 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -342,7 +342,7 @@ Buffer BufferWithOffsetAlignment(Array shape, } return BufferNode::make(data, dtype, shape, Array(), elem_offset, name, "", - data_alignment, offset_factor); + data_alignment, offset_factor, kDefault); } void GetBinds(const Array& args, diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc index 3e0615162a8f8..573ecffe1b08d 100644 --- a/src/lang/buffer.cc +++ b/src/lang/buffer.cc @@ -49,7 +49,8 @@ Buffer decl_buffer(Array shape, Expr(), name, "", - 0, 0); + 0, 0, + kDefault); } // Split the given expression w.r.t the add operator @@ -365,7 +366,8 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const { n->name + "_slice", n->scope, n->data_alignment, - 0); + 0, + n->buffer_type); } Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr offset) const { @@ -405,7 +407,8 @@ Buffer BufferNode::make(Var data, std::string name, std::string scope, int data_alignment, - int offset_factor) { + int offset_factor, + BufferType buffer_type) { auto n = make_node(); n->data = std::move(data); n->dtype = dtype; @@ -428,6 +431,12 @@ Buffer BufferNode::make(Var data, n->elem_offset = std::move(elem_offset); n->data_alignment = data_alignment; n->offset_factor = offset_factor; + n->buffer_type = buffer_type; + if (n->buffer_type == kAutoBroadcast && n->shape.size() > 0 && n->strides.empty()) { + for (size_t i = 0; i < n->shape.size(); ++i) { + n->strides.push_back(tvm::var("stride")); + } + } return Buffer(n); } diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc index 2822393d3f75e..d93d088644388 100644 --- a/src/pass/arg_binder.cc +++ b/src/pass/arg_binder.cc @@ -242,6 +242,21 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, check = IfThenElse::make(Not::make(is_null), check, Stmt()); init_nest_.emplace_back(Block::make(check, Evaluate::make(0))); } + } else if (buffer->buffer_type == kAutoBroadcast) { + Type stype = buffer->DefaultIndexType(); + Expr stride = make_const(stype, 1); + for (size_t i = buffer->shape.size(); i != 0; --i) { + size_t k = i - 1; + std::ostringstream field_name; + field_name << v_strides->name_hint << '[' << k << ']'; + Expr value = cast(buffer->shape[k].type(), + Load::make(tvm_shape_type, v_strides, + IntImm::make(Int(32), k), const_true(1))); + value = tvm::if_then_else(is_null, stride, value); + value = tvm::if_then_else(buffer->shape[k] == 1, 0, value); + Bind_(buffer->strides[k], value, field_name.str(), true); + stride = Simplify(stride * buffer->shape[k]); + } } else { std::ostringstream stride_null_err_msg; stride_null_err_msg << arg_name << ".strides: expected non-null strides."; diff --git a/src/pass/inject_copy_intrin.cc b/src/pass/inject_copy_intrin.cc index a906ee3e54741..8df5fe1f77572 100644 --- a/src/pass/inject_copy_intrin.cc +++ b/src/pass/inject_copy_intrin.cc @@ -160,7 +160,7 @@ class CopyIntrinInjector : public IRMutator { store_strides[loop_var_size], store->buffer_var->name_hint, GetStorageScope(store->buffer_var.get()), - 0, 0); + 0, 0, kDefault); Buffer src = BufferNode::make( Var(load->buffer_var.node_), load->type, @@ -169,7 +169,7 @@ class CopyIntrinInjector : public IRMutator { 
src_elem_offset, load->buffer_var->name_hint, GetStorageScope(load->buffer_var.get()), - 0, 0); + 0, 0, kDefault); *out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value); CHECK(out->defined()) << "flower function did not return correct stmt"; return true; diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc index 215f6d7397323..ff6b41612bf42 100644 --- a/src/pass/storage_flatten.cc +++ b/src/pass/storage_flatten.cc @@ -220,7 +220,7 @@ class StorageFlattener : public IRMutator { Var(key.GetName(), Handle()), op->type, shape, strides, Expr(), key.GetName(), skey.to_string(), - align, 0); + align, 0, kDefault); buf_map_[key] = e; Stmt body = this->Mutate(op->body); diff --git a/tests/python/unittest/test_lang_buffer.py b/tests/python/unittest/test_lang_buffer.py index e0bb0279c09f7..bd45eac2358a4 100644 --- a/tests/python/unittest/test_lang_buffer.py +++ b/tests/python/unittest/test_lang_buffer.py @@ -16,6 +16,7 @@ # under the License. import tvm from tvm.schedule import Buffer +import numpy as np def test_buffer(): m = tvm.var('m') @@ -108,6 +109,34 @@ def assert_simplified_equal(index_simplified, index_direct): index_direct = A.vload((0, ((k0 % (k1 / s)) / n) * n + ((k0 % (k1 / n)) % n + (k0 % k1)))) assert_simplified_equal(index_simplified, index_direct) +def test_buffer_broadcast(): + m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2") + n0, n1, n2 = tvm.var("n0"), tvm.var("n1"), tvm.var("n2") + o0, o1, o2 = tvm.var("o0"), tvm.var("o1"), tvm.var("o2") + + A = tvm.placeholder((m0, m1, m2), name='A') + B = tvm.placeholder((n0, n1, n2), name='B') + + C = tvm.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C') + + Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast") + Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast") + s = tvm.create_schedule(C.op) + + def check(): + if not tvm.module.enabled("llvm"): + return + fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb}) + ctx = tvm.cpu(0) + a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=(2, 1, 1)).astype(B.dtype), ctx) + c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx) + fadd(a, b, c) + tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) + + check() + + if __name__ == "__main__": test_buffer() test_buffer_access_ptr() @@ -115,3 +144,4 @@ def assert_simplified_equal(index_simplified, index_direct): test_buffer_access_ptr_extent() test_buffer_vload() test_buffer_index_merge_mult_mod() + test_buffer_broadcast() diff --git a/topi/include/topi/detail/extern.h b/topi/include/topi/detail/extern.h index ac00e52899fae..667722e465c44 100644 --- a/topi/include/topi/detail/extern.h +++ b/topi/include/topi/detail/extern.h @@ -49,7 +49,7 @@ inline Buffer DeclExternBuffer(Array shape, auto data = var(name, Handle()); auto elem_offset = Expr(); return BufferNode::make(data, dtype, shape, Array(), elem_offset, name, "", - -1, 0); + -1, 0, kDefault); } /*! 
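For reference, a small usage sketch of the auto-broadcast buffers introduced above, modelled on the new test_buffer_broadcast unit test; it assumes an llvm-enabled local build, and the string accepted for the flag is "auto_broadcast".

    import numpy as np
    import tvm

    m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2")
    n0, n1, n2 = tvm.var("n0"), tvm.var("n1"), tvm.var("n2")
    o0, o1, o2 = tvm.var("o0"), tvm.var("o1"), tvm.var("o2")
    A = tvm.placeholder((m0, m1, m2), name="A")
    B = tvm.placeholder((n0, n1, n2), name="B")
    C = tvm.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name="C")

    # auto_broadcast buffers map B[i, j, k] -> B[i, 0, k] wherever B's extent is 1,
    # by binding a zero stride for that dimension at call time.
    Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
    Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
    s = tvm.create_schedule(C.op)
    fadd = tvm.build(s, [A, B, C], target="llvm", name="bcast_add", binds={A: Ab, B: Bb})

    ctx = tvm.cpu(0)
    a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx)
    b = tvm.nd.array(np.random.uniform(size=(2, 1, 3)).astype(B.dtype), ctx)
    c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx)
    fadd(a, b, c)
    np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
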
From e3d6074a5b204940a7dcb0f50dbf679c7dc072f3 Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Tue, 2 Jul 2019 09:14:52 -0700 Subject: [PATCH 08/26] Clean up pass.h (#3312) --- docs/api/python/relay/index.rst | 3 +- include/tvm/relay/{pass.h => analysis.h} | 117 +----- include/tvm/relay/transform.h | 97 ++--- nnvm/tests/python/compiler/test_to_relay.py | 9 +- .../graph_tuner/utils/traverse_graph.py | 13 +- python/tvm/autotvm/graph_tuner/utils/utils.py | 6 +- python/tvm/relay/__init__.py | 8 +- .../tvm/relay/{_ir_pass.py => _analysis.py} | 4 +- python/tvm/relay/{ir_pass.py => analysis.py} | 333 +++--------------- python/tvm/relay/backend/interpreter.py | 4 +- python/tvm/relay/expr.pyi | 4 +- python/tvm/relay/frontend/caffe2.py | 4 +- python/tvm/relay/frontend/common.py | 16 +- python/tvm/relay/frontend/coreml.py | 4 +- python/tvm/relay/frontend/darknet.py | 4 +- python/tvm/relay/frontend/keras.py | 4 +- python/tvm/relay/frontend/mxnet.py | 32 +- python/tvm/relay/frontend/onnx.py | 6 +- python/tvm/relay/frontend/tensorflow.py | 24 +- python/tvm/relay/frontend/tflite.py | 4 +- python/tvm/relay/module.py | 9 - python/tvm/relay/quantize/quantize.py | 4 +- python/tvm/relay/testing/dcgan.py | 2 +- python/tvm/relay/testing/densenet.py | 2 +- python/tvm/relay/testing/dqn.py | 2 +- python/tvm/relay/testing/inception_v3.py | 2 +- python/tvm/relay/testing/init.py | 5 +- python/tvm/relay/testing/lstm.py | 2 +- python/tvm/relay/testing/mlp.py | 2 +- python/tvm/relay/testing/mobilenet.py | 2 +- python/tvm/relay/testing/resnet.py | 2 +- python/tvm/relay/testing/squeezenet.py | 2 +- python/tvm/relay/testing/vgg.py | 2 +- python/tvm/relay/transform.py | 87 ++++- src/relay/backend/build_module.cc | 1 + src/relay/backend/compile_engine.cc | 2 +- src/relay/backend/compile_engine.h | 3 +- src/relay/backend/graph_plan_memory.cc | 2 +- src/relay/backend/interpreter.cc | 2 +- src/relay/backend/utils.h | 1 - src/relay/backend/vm/lambda_lift.cc | 2 +- src/relay/backend/vm/vm.cc | 5 +- src/relay/ir/alpha_equal.cc | 2 +- src/relay/ir/expr_functor.cc | 2 +- src/relay/ir/hash.cc | 6 +- src/relay/ir/module.cc | 52 ++- src/relay/pass/alter_op_layout.cc | 6 +- src/relay/pass/canonicalize_cast.cc | 2 +- src/relay/pass/canonicalize_ops.cc | 5 +- src/relay/pass/combine_parallel_conv2d.cc | 5 +- src/relay/pass/dead_code.cc | 3 +- src/relay/pass/device_annotation.cc | 7 +- src/relay/pass/eliminate_common_subexpr.cc | 5 +- src/relay/pass/eta_expand.cc | 20 +- src/relay/pass/feature.cc | 4 +- src/relay/pass/fold_constant.cc | 15 +- src/relay/pass/fold_scale_axis.cc | 15 +- src/relay/pass/forward_rewrite.cc | 34 +- src/relay/pass/fuse_ops.cc | 5 +- src/relay/pass/gradient.cc | 7 +- src/relay/pass/kind_check.cc | 4 +- src/relay/pass/mac_count.cc | 4 +- src/relay/pass/match_exhaustion.cc | 21 +- src/relay/pass/partial_eval.cc | 14 +- src/relay/pass/pass_manager.cc | 15 - src/relay/pass/quantize.cc | 19 +- src/relay/pass/simplify_inference.cc | 5 +- src/relay/pass/to_a_normal_form.cc | 2 +- src/relay/pass/to_graph_normal_form.cc | 1 + src/relay/pass/type_infer.cc | 7 +- src/relay/pass/type_solver.cc | 2 +- src/relay/pass/type_solver.h | 2 +- src/relay/pass/util.cc | 14 +- src/relay/pass/well_formed.cc | 4 +- tests/cpp/relay_build_module_test.cc | 3 +- tests/cpp/relay_pass_type_infer_test.cc | 10 +- tests/cpp/relay_transform_sequential.cc | 6 +- .../frontend/caffe2/model_zoo/squeezenet.py | 2 +- tests/python/frontend/caffe2/test_graph.py | 12 +- tests/python/frontend/mxnet/test_graph.py | 10 +- 
.../nnvm_to_relay/test_alter_conv2d.py | 13 +- tests/python/relay/test_adt.py | 22 +- .../relay/test_backend_compile_engine.py | 8 +- .../relay/test_backend_graph_runtime.py | 7 +- .../python/relay/test_backend_interpreter.py | 2 +- tests/python/relay/test_error_reporting.py | 5 +- tests/python/relay/test_feature.py | 7 +- tests/python/relay/test_ir_bind.py | 4 +- tests/python/relay/test_ir_nodes.py | 2 +- tests/python/relay/test_ir_parser.py | 2 +- tests/python/relay/test_ir_well_formed.py | 2 +- tests/python/relay/test_op_grad_level1.py | 16 +- tests/python/relay/test_op_level1.py | 46 ++- tests/python/relay/test_op_level10.py | 26 +- tests/python/relay/test_op_level2.py | 53 +-- tests/python/relay/test_op_level3.py | 66 ++-- tests/python/relay/test_op_level4.py | 21 +- tests/python/relay/test_op_level5.py | 54 +-- tests/python/relay/test_pass_alpha_equal.py | 10 +- .../python/relay/test_pass_alter_op_layout.py | 171 ++++----- tests/python/relay/test_pass_annotation.py | 128 +++---- .../relay/test_pass_canonicalize_cast.py | 7 +- tests/python/relay/test_pass_check_kind.py | 2 +- .../test_pass_combine_parallel_conv2d.py | 50 +-- .../relay/test_pass_dead_code_elimination.py | 34 +- .../test_pass_eliminate_common_subexpr.py | 22 +- tests/python/relay/test_pass_eta_expand.py | 9 +- tests/python/relay/test_pass_fold_constant.py | 64 ++-- .../python/relay/test_pass_fold_scale_axis.py | 106 +++--- tests/python/relay/test_pass_fuse_ops.py | 151 ++++---- tests/python/relay/test_pass_gradient.py | 48 ++- tests/python/relay/test_pass_mac_count.py | 35 +- tests/python/relay/test_pass_manager.py | 19 +- tests/python/relay/test_pass_partial_eval.py | 30 +- tests/python/relay/test_pass_quantize.py | 20 +- .../relay/test_pass_simplify_inference.py | 11 +- .../relay/test_pass_to_a_normal_form.py | 32 +- .../relay/test_pass_to_graph_normal_form.py | 18 +- .../python/relay/test_pass_unmatched_cases.py | 2 +- tests/python/relay/test_pass_vars.py | 6 +- tests/python/relay/test_type_infer.py | 74 ++-- tests/python/relay/test_type_solver.py | 2 +- tests/python/relay/test_typecall.py | 11 +- .../python/unittest/test_graph_tuner_core.py | 2 +- .../python/unittest/test_graph_tuner_utils.py | 8 +- tutorials/frontend/using_external_lib.py | 2 +- vta/python/vta/top/graphpack.py | 18 +- vta/scripts/tune_resnet.py | 1 - vta/tutorials/autotvm/tune_relay_vta.py | 1 - .../frontend/deploy_resnet_on_vta.py | 1 - 130 files changed, 1273 insertions(+), 1369 deletions(-) rename include/tvm/relay/{pass.h => analysis.h} (71%) rename python/tvm/relay/{_ir_pass.py => _analysis.py} (89%) rename python/tvm/relay/{ir_pass.py => analysis.py} (53%) diff --git a/docs/api/python/relay/index.rst b/docs/api/python/relay/index.rst index 39a68b6d1f5d5..90746b8e5d4ee 100644 --- a/docs/api/python/relay/index.rst +++ b/docs/api/python/relay/index.rst @@ -33,7 +33,8 @@ compiler stack. expr frontend image - ir_pass + analysis + transform module nn op diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/analysis.h similarity index 71% rename from include/tvm/relay/pass.h rename to include/tvm/relay/analysis.h index 79172c3743167..e3d16b6eda739 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/analysis.h @@ -18,42 +18,21 @@ */ /*! - * \file tvm/relay/pass.h - * \brief The set of Relay passes written in C++. - */ -#ifndef TVM_RELAY_PASS_H_ -#define TVM_RELAY_PASS_H_ + * \file tvm/relay/analysis.h + * \brief The set of Relay analysis passes written in C++. 
+ */ +#ifndef TVM_RELAY_ANALYSIS_H_ +#define TVM_RELAY_ANALYSIS_H_ -#include -#include +#include #include #include -#include #include -#include -#include -#include #include -#include namespace tvm { namespace relay { -/*! - * \brief Infer the type of an expression. - * - * The result of type checking is a new expression with unambigous - * type information filled in, as well as it's checked type field - * populated with the result type. - * - * \param expr The expression to type check. - * \param mod The module used for referencing global functions, can be - * None. - * - * \return A type checked expression with its checked_type field populated. - */ -TVM_DLL Expr InferType(const Expr& expr, const Module& mod); - /*! * \brief Infer the type of a function as if it is mapped to var in the mod. * @@ -64,7 +43,8 @@ TVM_DLL Expr InferType(const Expr& expr, const Module& mod); * \return A type checked Function with its checked_type field populated. * \note this function mutates mod and is not thread-safe. */ -TVM_DLL Function InferType(const Function& f, const Module& mod, +TVM_DLL Function InferType(const Function& f, + const Module& mod, const GlobalVar& var); /*! @@ -271,58 +251,6 @@ TVM_DLL tvm::Array AllTypeVars(const Expr& expr, const Module& mod); */ TVM_DLL tvm::Array AllTypeVars(const Type& t, const Module& mod); -/*! - * \brief Fold constant expressions. - * - * \param expr the expression to be optimized. - * - * \return The optimized expression. - */ -TVM_DLL Expr FoldConstant(const Expr& expr); - -/*! - * \brief Fuse operations into expr into seperate functions. - * - * \param expr The expression. - * \param fuse_opt_level Optimization level. - * \param mod the module. - * - * \return The optimized expression. - */ -TVM_DLL Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& mod); - -/*! - * \brief Apply rewrite rules to rewrite the expr in post DFS order. - * - * \param expr The expression. - * \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite - * rule function. - * \param fcontext Additional callback to provide context argument for each call node. - * \param fmulti_ref_trigger Transformation function to be called when - * an Expr consumed by multiple callers. - * \return The rewritten expression. - */ -TVM_DLL Expr ForwardRewrite(const Expr& expr, - const std::string& rewrite_map_attr_name, - std::function fcontext = nullptr, - std::function fmulti_ref_trigger = nullptr); - -/*! - * \brief Apply rewrite rules to rewrite the expr in post DFS order. - * - * \param expr The expression. - * \param rewrite_func The rewrite func that will apply to all operators. - * \param fcontext Additional callback to provide context argument for each call node. - * \param fmulti_ref_trigger Transformation function to be called when - * an Expr consumed by multiple callers. - * - * \return The rewritten expression. - */ -TVM_DLL Expr ForwardRewrite(const Expr& expr, - const FForwardRewrite& rewrite_func, - std::function fcontext = nullptr, - std::function fmulti_ref_trigger = nullptr); - /*! * \brief Rewrite the annotated program. * @@ -364,19 +292,6 @@ TVM_DLL Map CollectDeviceAnnotationOps(const Expr& expr); */ TVM_DLL Array UnmatchedCases(const Match& match, const Module& mod); -/*! - * \brief Bind the free variables to a Relay expression. - * - * Parameter binding can only happen if expr is a Function. - * binds cannot change internal arguments of internal functions. - * - * \param expr The function to be binded. 
- * \param binds The map of arguments to - * - * \return The expression with all free vars bound. - */ -TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds); - /*! \brief A hashing structure in the style of std::hash. */ struct StructuralHash { /*! \brief Hash a Relay type. @@ -388,7 +303,6 @@ struct StructuralHash { * \return the hash value. */ size_t operator()(const Type& type) const; - /*! \brief Hash a Relay expression. * * Implements structural hashing of a Relay expression. @@ -400,20 +314,7 @@ struct StructuralHash { size_t operator()(const Expr& expr) const; }; -namespace vm { - -/*! - * \brief Compile a module, and construct the virtual machine. - * - * \param mod The module to compile. - * - * \return The constructed virtual machine. - */ -runtime::vm::VirtualMachine CompileModule(const Module& mod); - -} // namespace vm - } // namespace relay } // namespace tvm -#endif // TVM_RELAY_PASS_H_ +#endif // TVM_RELAY_ANALYSIS_H_ diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 9ae71d824f94e..bb8638abbabf1 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -378,36 +378,6 @@ TVM_DLL Pass FoldConstant(); */ TVM_DLL Pass FuseOps(int fuse_opt_level = -1); -/*! - * \brief Apply rewrite rules to rewrite the expr in post DFS order. - * - * \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite - * rule function. - * \param fcontext Additional callback to provide context argument for each call node. - * \param fmulti_ref_trigger Transformation function to be called when - * an Expr consumed by multiple callers. - * - * \return The pass. - */ -TVM_DLL Pass ForwardRewrite(const std::string& rewrite_map_attr_name, - std::function fcontext = nullptr, - std::function - fmulti_ref_trigger = nullptr); - -/*! - * \brief Apply rewrite rules to rewrite the expr in post DFS order. - * - * \param rewrite_func The rewrite func that will apply to all operators. - * \param fcontext Additional callback to provide context argument for each call node. - * \param fmulti_ref_trigger Transformation function to be called when - * an Expr consumed by multiple callers. - * - * \return The pass. - */ -TVM_DLL Pass ForwardRewrite(const FForwardRewrite& rewrite_func, - std::function fcontext = nullptr, - std::function fmulti_ref_trigger = nullptr); - /*! * \brief Rewrite the annotated program. * @@ -554,21 +524,68 @@ TVM_DLL Pass CanonicalizeCast(); */ TVM_DLL Pass EtaExpand(); +} // namespace transform + /*! - * \brief This is a helper function that runs a some optimization passes on - * a certain expression and returns the optimized version. With the help of this - * function, users don't need to manually construct a module, then perform - * passes, and finally and extract the target function/expression from the - * returned module frequently. + * \brief Bind the free variables to a Relay expression. This is a helper + * function usually called by other pass functions to help optimizations. * - * \param expr The expression to be optimized. - * \param passes The passses that will be applied on the given expression. + * \param expr The input expression. + * \param binds The variable to expression map that will be used to help the + * binding. * - * \return The optimized expression. + * \return The updated expression. */ -TVM_DLL Expr OptimizeOnExpr(const Expr& expr, const Array& passes); +TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds); + +/*! 
+ * \brief Infer the type of a function as if it is mapped to var in the mod. + * + * \param f the function. + * \param mod The module used for referencing global functions. + * \param var The global variable corresponding to the function. + * + * \return A type checked Function with its checked_type field populated. + * \note this function mutates mod and is not thread-safe. + */ +TVM_DLL Function InferType(const Function& f, + const Module& mod, + const GlobalVar& var); + +/*! + * \brief Apply rewrite rules to rewrite the expr in post DFS order. This + * function is used as a helper function to rewrtie an expression in a pass. + * + * \param expr The expression. + * \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite + * rule function. + * \param fcontext Additional callback to provide context argument for each call node. + * \param fmulti_ref_trigger Transformation function to be called when + * an Expr consumed by multiple callers. + * \return The rewritten expression. + */ +TVM_DLL Expr ForwardRewrite(const Expr& expr, + const std::string& rewrite_map_attr_name, + std::function fcontext = nullptr, + std::function fmulti_ref_trigger = nullptr); + +/*! + * \brief Apply rewrite rules to rewrite the expr in post DFS order. This + * function is used as a helper function to rewrtie an expression in a pass. + * + * \param expr The expression. + * \param rewrite_func The rewrite func that will apply to all operators. + * \param fcontext Additional callback to provide context argument for each call node. + * \param fmulti_ref_trigger Transformation function to be called when + * an Expr consumed by multiple callers. + * + * \return The rewritten expression. + */ +TVM_DLL Expr ForwardRewrite(const Expr& expr, + const FForwardRewrite& rewrite_func, + std::function fcontext = nullptr, + std::function fmulti_ref_trigger = nullptr); -} // namespace transform } // namespace relay } // namespace tvm diff --git a/nnvm/tests/python/compiler/test_to_relay.py b/nnvm/tests/python/compiler/test_to_relay.py index e79831d06cf26..dac14a8c1f220 100644 --- a/nnvm/tests/python/compiler/test_to_relay.py +++ b/nnvm/tests/python/compiler/test_to_relay.py @@ -18,7 +18,7 @@ from nnvm import testing from nnvm import to_relay import tvm -from tvm.relay import ir_pass +from tvm.relay import transform from tvm.relay import create_executor from tvm.contrib import graph_runtime import numpy as np @@ -41,10 +41,11 @@ def check_model(sym, shapes, dtypes, params): nnvm_rts.run(**inputs) nnvm_out = nnvm_rts.get_output(0) relay_model, params = to_relay.to_relay(net, shapes, dtypes, params) - relay_model = ir_pass.infer_type(relay_model) - relay_rts = create_executor(kind='graph', ctx=tvm.cpu(0), target='llvm') + mod = tvm.relay.Module.from_expr(relay_model) + mod = transform.InferType()(mod) + relay_rts = create_executor(kind='graph', mod=mod, ctx=tvm.cpu(0), target='llvm') inputs.update(params) - relay_out = relay_rts.evaluate(relay_model)(*list(inputs.values())) + relay_out = relay_rts.evaluate()(*list(inputs.values())) np.testing.assert_allclose(nnvm_out.asnumpy(), relay_out.asnumpy()) # def test_mlp(): diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py index 7e7f1749eae73..62e409fec1a01 100644 --- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -21,6 +21,7 @@ import topi from tvm import relay, autotvm +from tvm.relay import transform from 
tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple from tvm.relay.ty import TupleType, TensorType from tvm.autotvm.task import TaskExtractEnv @@ -80,6 +81,14 @@ def expr2graph(expr, target_ops, node_dict, node_list): task_pos += 1 +def _infer_type(node): + """A method to infer the type of a relay expression.""" + mod = relay.Module.from_expr(node) + mod = transform.InferType()(mod) + entry = mod[mod.entry_func] + return entry if isinstance(node, relay.Function) else entry.body + + def _expr2graph_impl(expr, target_ops, node_dict, node_list): """Implementation to convert relay expr to graph data structure """ @@ -99,7 +108,7 @@ def _traverse_expr(node): node_entry["inputs"] += node_list[in_node_idx]["inputs"] else: node_entry["inputs"].append([in_node_idx, 0, 0]) - infer_out = relay.ir_pass.infer_type(node) + infer_out = _infer_type(node) out_type = infer_out._checked_type_ if isinstance(out_type, TensorType): node_entry["types"].append(out_type) @@ -168,7 +177,7 @@ def _traverse_expr(node): node_dict[node] = node_index node_list.append(node_entry) - relay.ir_pass.post_order_visit(expr, _traverse_expr) + relay.analysis.post_order_visit(expr, _traverse_expr) def get_direct_ancestor(node_list, visited_dict, target_ops, node_idx, input_names): diff --git a/python/tvm/autotvm/graph_tuner/utils/utils.py b/python/tvm/autotvm/graph_tuner/utils/utils.py index 6151734299af6..797a38ae36983 100644 --- a/python/tvm/autotvm/graph_tuner/utils/utils.py +++ b/python/tvm/autotvm/graph_tuner/utils/utils.py @@ -17,6 +17,7 @@ # pylint: disable=eval-used,invalid-name,too-many-arguments """Utility functions""" from tvm import relay +from tvm.relay import transform def has_multiple_inputs(node_list, node_idx, input_names): @@ -107,4 +108,7 @@ def bind_inputs(expr, input_shapes=None, input_dtypes="float32"): rebind_dict[var] = updated_input_dict[var.name_hint] updated_expr = relay.expr.bind(expr, rebind_dict) - return relay.ir_pass.infer_type(updated_expr) + mod = relay.Module.from_expr(updated_expr) + mod = transform.InferType()(mod) + entry = mod[mod.entry_func] + return entry if isinstance(updated_expr, relay.Function) else entry.body diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 5536e503e6b67..dfac85bb1ed28 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -24,7 +24,7 @@ from . import expr_functor from . import module from . import adt -from . import ir_pass +from . import analysis from . import transform from .build_module import build, create_executor from .transform import build_config @@ -32,6 +32,7 @@ from . import parser from . import debug from . import param_dict +from . import feature # Root operators from .op import Op @@ -101,7 +102,7 @@ bind = expr.bind module_pass = transform.module_pass function_pass = transform.function_pass -alpha_equal = ir_pass.alpha_equal +alpha_equal = analysis.alpha_equal # ExprFunctor ExprFunctor = expr_functor.ExprFunctor @@ -122,3 +123,6 @@ ModulePass = transform.ModulePass FunctionPass = transform.FunctionPass Sequential = transform.Sequential + +# Feature +Feature = feature.Feature diff --git a/python/tvm/relay/_ir_pass.py b/python/tvm/relay/_analysis.py similarity index 89% rename from python/tvm/relay/_ir_pass.py rename to python/tvm/relay/_analysis.py index 3a0e0ac846b99..32a7324ae29f5 100644 --- a/python/tvm/relay/_ir_pass.py +++ b/python/tvm/relay/_analysis.py @@ -14,8 +14,8 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. -"""FFI exposing the Relay type inference and checking.""" +"""FFI exposing the passes for Relay program analysis.""" from tvm._ffi.function import _init_api -_init_api("relay._ir_pass", __name__) +_init_api("relay._analysis", __name__) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/analysis.py similarity index 53% rename from python/tvm/relay/ir_pass.py rename to python/tvm/relay/analysis.py index 52dc34d7aac9d..ee8ce985fcbc0 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/analysis.py @@ -20,7 +20,7 @@ This file contains the set of passes for Relay, which exposes an interface for configuring the passes and scripting them in Python. """ -from . import _ir_pass +from . import _analysis from . import _make from .expr import Expr from .ty import Type @@ -41,71 +41,7 @@ def post_order_visit(expr, fvisit): fvisit : function The visitor function to be applied. """ - return _ir_pass.post_order_visit(expr, fvisit) - -def infer_type(expr, mod=None): - """Infer the type of expr under the context of mod. - - Parameters - ---------- - expr: tvm.relay.Expr - The input expression. - - mod: Optional[tvm.relay.Module] - The global module. - - Returns - ------- - checked_expr : tvm.relay.Expr - The checked expression. - """ - return _ir_pass.infer_type(expr, mod) - - -def backward_fold_scale_axis(expr): - """Backward fold axis scaling into weights of conv2d/dense. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression, we expect that expr's types - should be fully inferred by infer_type. - - Returns - ------- - folded_expr : tvm.relay.Expr - The folded expression after transformation. - - Note - ---- - It is recommended to call backward_fold_scale_axis - before using forward_fold_scale_axis. - As backward folding targets common conv-bn pattern. - """ - return _ir_pass.backward_fold_scale_axis(expr) - - -def forward_fold_scale_axis(expr): - """Fold the scaling of axis into weights of conv2d/dense. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression, we expect that expr's types - should be fully inferred by infer_type. - - Returns - ------- - folded_expr : tvm.relay.Expr - The folded expression after transformation. - - Note - ---- - It is recommended to call backward_fold_scale_axis - before using forward_fold_scale_axis. - As backward folding targets common conv-bn pattern. - """ - return _ir_pass.forward_fold_scale_axis(expr) + return _analysis.post_order_visit(expr, fvisit) def well_formed(expr): @@ -121,12 +57,13 @@ def well_formed(expr): well_form : bool Whether the input expression is well formed """ - return _ir_pass.well_formed(expr) + return _analysis.well_formed(expr) def check_kind(t, mod=None): """Check that the type is well kinded and return the kind. - For example, this mean type cannot has tensor of tensor, or is a tuple type of 2 shapes. + For example, this mean type cannot has tensor of tensor, or is a tuple type + of 2 shapes. Parameters ---------- @@ -149,9 +86,9 @@ def check_kind(t, mod=None): assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)])) == Type """ if mod is not None: - return _ir_pass.check_kind(t, mod) + return _analysis.check_kind(t, mod) else: - return _ir_pass.check_kind(t) + return _analysis.check_kind(t) def free_vars(expr): @@ -173,7 +110,7 @@ def free_vars(expr): neural networks: usually this means weights of previous are ordered first. 
""" - return _ir_pass.free_vars(expr) + return _analysis.free_vars(expr) def bound_vars(expr): @@ -189,7 +126,7 @@ def bound_vars(expr): free : List[tvm.relay.Var] The list of bound variables in post-DFS order. """ - return _ir_pass.bound_vars(expr) + return _analysis.bound_vars(expr) def all_vars(expr): @@ -205,7 +142,7 @@ def all_vars(expr): free : List[tvm.relay.Var] The list of all variables in post-DFS order. """ - return _ir_pass.all_vars(expr) + return _analysis.all_vars(expr) def free_type_vars(expr, mod=None): @@ -225,7 +162,7 @@ def free_type_vars(expr, mod=None): The list of free type variables in post-DFS order """ use_mod = mod if mod is not None else Module() - return _ir_pass.free_type_vars(expr, use_mod) + return _analysis.free_type_vars(expr, use_mod) def bound_type_vars(expr, mod=None): @@ -245,7 +182,7 @@ def bound_type_vars(expr, mod=None): The list of bound type variables in post-DFS order """ use_mod = mod if mod is not None else Module() - return _ir_pass.bound_type_vars(expr, use_mod) + return _analysis.bound_type_vars(expr, use_mod) def all_type_vars(expr, mod=None): @@ -255,6 +192,7 @@ def all_type_vars(expr, mod=None): ---------- expr : Union[tvm.relay.Expr,tvm.relay.Type] The input expression/type + mod : Optional[tvm.relay.Module] The global module @@ -264,41 +202,7 @@ def all_type_vars(expr, mod=None): The list of all type variables in post-DFS order """ use_mod = mod if mod is not None else Module() - return _ir_pass.all_type_vars(expr, use_mod) - - -def simplify_inference(expr): - """ Simplify the data-flow graph for inference phase. - - Parameters - ---------- - expr : tvm.relay.Expr - The input Expression - - Returns - ------- - result : tvm.relay.Expr - An expression which is semantically equal to the input expression, - but with some simplification - """ - return _ir_pass.simplify_inference(expr) - - -def canonicalize_ops(expr): - """ Canonicalize special operators to basic operators. - This can simplify latter analysis. (e.g. Expand bias_add to expand_dims and broadcast_add.) - - Parameters - ---------- - expr : tvm.relay.Expr - The input Expression - - Returns - ------- - result : tvm.relay.Expr - An expression without bias_add - """ - return _ir_pass.canonicalize_ops(expr) + return _analysis.all_type_vars(expr, use_mod) def alpha_equal(lhs, rhs): @@ -342,128 +246,6 @@ def graph_equal(lhs, rhs): return bool(_make._graph_equal(lhs, rhs)) -def structural_hash(value): - """Hash a Relay expression structurally. - - Parameters - ---------- - expr : Union[tvm.relay.Expr, tvm.relay.Type] - The expression to hash. - - Returns - ------- - result : int - The hash value - """ - if isinstance(value, Expr): - return int(_ir_pass._expr_hash(value)) - elif isinstance(value, Type): - return int(_ir_pass._type_hash(value)) - else: - msg = ("found value of type {0} expected" + - "relay.Expr or relay.Type").format(type(value)) - raise TypeError(msg) - - -def fold_constant(expr): - """Fold the constant expression in expr. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - - Returns - ------- - transformed_expr : tvm.relay.Expr - The transformed expression. - """ - return _ir_pass.FoldConstant(expr) - - -def fuse_ops(expr, opt_level=1, mod=None): - """Fuse operators in expr together. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - - opt_level : int - The level of fuse optimization. - - mod : tvm.relay.Module - The module to perform fusion over. 
- - Returns - ------- - transformed_expr : tvm.relay.Expr - Transformed expression, containing fused result. - """ - return _ir_pass.FuseOps(expr, opt_level, mod) - - -def combine_parallel_conv2d(expr, min_num_branches=3): - """Combine multiple conv2d into one. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - - min_num_branches : int - The minimum number of parallel branches when the transformation should be applied. - - Returns - ------- - transformed_expr : tvm.relay.Expr - Transformed expression - """ - return _ir_pass.CombineParallelConv2D(expr, min_num_branches) - - -def alter_op_layout(expr): - """Alternate the layouts of operators or replace primitive operators with - other expressions. - This pass can be used for computing convolution in custom layouts or - other general weight pre-transformation. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - - Returns - ------- - transformed_expr : tvm.relay.Expr - Transformed expression with alternated layout. - """ - return _ir_pass.AlterOpLayout(expr) - - -def rewrite_annotated_ops(expr, fallback_device): - """Rewrite the annotated program where annotation operators, e.g. - `on_deivce`, mark which device an expression should be scheduled to. - This pass helps heterogeneous execution where different operators may need - to be allocated on various devices. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - - fallback_device : int - The fallback device type. It is also used as the default device for - operators with no annotated device. - - Returns - ------- - transformed_expr : tvm.relay.Expr - Transformed expression with cross device data copy operators. - """ - return _ir_pass.RewriteDeviceAnnotation(expr, fallback_device) - - def collect_device_info(expr): """Collect the device allocation map for the given expression. The device ids are propagated from the `device_copy` operators. @@ -478,7 +260,7 @@ def collect_device_info(expr): ret : Dict[tvm.relay.expr, int] A dictionary mapping tvm.relay.Expr to device type. """ - return _ir_pass.CollectDeviceInfo(expr) + return _analysis.CollectDeviceInfo(expr) def collect_device_annotation_ops(expr): @@ -495,38 +277,7 @@ def collect_device_annotation_ops(expr): A dictionary mapping tvm.relay.Expr to device type where the keys are annotation expressions. """ - return _ir_pass.CollectDeviceAnnotationOps(expr) - - -def gradient(expr, mod=None, mode='higher_order'): - """ - Transform the input function, - returning a function that calculate the original result, - paired with gradient of the input. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression, which is a Function or a GlobalVar. - - mod : Optional[tvm.relay.Module] - - mode : Optional[String] - The mode of the automatic differentiation algorithm. - 'first_order' only work on first order code, but will not produce reference nor closure. - 'higher_order' work on all code using reference and closure. - - Returns - ------- - expr : tvm.relay.Expr - The transformed expression. 
- """ - if mode == 'first_order': - return _ir_pass.first_order_gradient(expr, mod) - elif mode == 'higher_order': - return _ir_pass.gradient(expr, mod) - else: - raise Exception('unknown mode') + return _analysis.CollectDeviceAnnotationOps(expr) def get_total_mac_number(expr): @@ -543,27 +294,7 @@ def get_total_mac_number(expr): result : int64 The number of MACs (multiply-accumulate) of a model """ - return _ir_pass.GetTotalMacNumber(expr) - - -def eliminate_common_subexpr(expr, fskip=None): - """ - Eliminate common subexpressions. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - - fskip : function - The callback function that decides whether an expression should be skipped. - - Returns - ------- - result : tvm.relay.Expr - The output expression. - """ - return _ir_pass.eliminate_common_subexpr(expr, fskip) + return _analysis.GetTotalMacNumber(expr) def unmatched_cases(match, mod=None): @@ -574,15 +305,16 @@ def unmatched_cases(match, mod=None): ---------- match : tvm.relay.Match The match expression + mod : Optional[tvm.relay.Module] The module (defaults to an empty module) Returns ------- missing_patterns : [tvm.relay.Pattern] - Patterns that the match expression does not catch. + Patterns that the match expression does not catch. """ - return _ir_pass.unmatched_cases(match, mod) + return _analysis.unmatched_cases(match, mod) def detect_feature(a, b=None): @@ -605,4 +337,27 @@ def detect_feature(a, b=None): """ if isinstance(a, Module): a, b = b, a - return set([Feature(int(x)) for x in _ir_pass.detect_feature(a, b)]) + return set([Feature(int(x)) for x in _analysis.detect_feature(a, b)]) + + +def structural_hash(value): + """Hash a Relay expression structurally. + + Parameters + ---------- + expr : Union[tvm.relay.Expr, tvm.relay.Type] + The expression to hash. + + Returns + ------- + result : int + The hash value + """ + if isinstance(value, Expr): + return int(_analysis._expr_hash(value)) + elif isinstance(value, Type): + return int(_analysis._type_hash(value)) + else: + msg = ("found value of type {0} expected" + + "relay.Expr or relay.Type").format(type(value)) + raise TypeError(msg) diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index cf643f61243cc..5b7d9eda46b49 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -21,7 +21,7 @@ import numpy as np from . import _backend -from .. import _make, ir_pass, transform +from .. import _make, analysis, transform from .. import module from ... import register_func, nd from ..base import NodeBase, register_relay_node @@ -239,7 +239,7 @@ def evaluate(self, expr=None, binds=None): return self._make_executor() if isinstance(expr, Function): - assert not ir_pass.free_vars(expr) + assert not analysis.free_vars(expr) if isinstance(expr, (Function, GlobalVar)): return self._make_executor(expr) diff --git a/python/tvm/relay/expr.pyi b/python/tvm/relay/expr.pyi index b7395c365390a..d264e99e05770 100644 --- a/python/tvm/relay/expr.pyi +++ b/python/tvm/relay/expr.pyi @@ -19,7 +19,7 @@ from typing import List import tvm from .base import Span, NodeBase from .ty import Type, TypeParam -from ._ir_pass import _get_checked_type +from ._analysis import _get_checked_type class Expr(NodeBase): @@ -128,4 +128,4 @@ class If(Expr): def __init__(self, cond, true_value, false_value): # type: (Expr, Expr, Expr) -> None - ... \ No newline at end of file + ... 
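The frontend and tuner diffs that follow all repeat the same replacement idiom: because the expression-level `ir_pass.infer_type(expr)` wrapper is gone, an expression is wrapped in a fresh `relay.Module`, the `InferType` module pass is run, and the type-checked result is read back from the module's entry function. A minimal sketch of that idiom, written against the API as it appears in this series (nothing beyond `Module.from_expr`, `transform.InferType`, and `mod.entry_func` is assumed):

    from tvm import relay
    from tvm.relay import transform

    def infer_type(expr):
        """Type-check a bare Relay expression by routing it through a Module."""
        mod = relay.Module.from_expr(expr)   # wrap the expression in a module
        mod = transform.InferType()(mod)     # Module -> Module pass
        entry = mod[mod.entry_func]          # the entry function holds the result
        # hand back a Function if we started from one, otherwise its checked body
        return entry if isinstance(expr, relay.Function) else entry.body

    # usage: recover the checked type of an intermediate node
    x = relay.var("x", shape=(1, 3, 224, 224))
    y = relay.nn.relu(x)
    print(infer_type(y).checked_type)

Note that traverse_graph.py, frontend/common.py, frontend/mxnet.py, and frontend/tensorflow.py below each define essentially this helper locally rather than sharing a single copy.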
diff --git a/python/tvm/relay/frontend/caffe2.py b/python/tvm/relay/frontend/caffe2.py index 18489b380ee71..91f0409b39d5f 100644 --- a/python/tvm/relay/frontend/caffe2.py +++ b/python/tvm/relay/frontend/caffe2.py @@ -18,7 +18,7 @@ """Caffe2 frontend""" from __future__ import absolute_import as _abs import tvm -from .. import ir_pass +from .. import analysis from .. import expr as _expr from .. import module as _module from .. import op as _op @@ -450,7 +450,7 @@ def from_caffe2(self, init_net, predict_net): else: outputs = out[0] - func = _expr.Function(ir_pass.free_vars(outputs), outputs) + func = _expr.Function(analysis.free_vars(outputs), outputs) self._mod[self._mod.entry_func] = func return self._mod, self._params diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index efd198803c2b6..6d8e14569e73c 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -19,8 +19,8 @@ import logging from topi.util import get_const_tuple from .. import expr as _expr -from .. import expr as _expr -from .. import ir_pass +from .. import module as _module +from .. import transform as _transform from .. import op as _op @@ -407,9 +407,17 @@ def get_name(node): name = node.name_hint return name + +def infer_type(node): + """A method to infer the type of an intermediate node in the relay graph.""" + mod = _module.Module.from_expr(node) + mod = _transform.InferType()(mod) + entry = mod[mod.entry_func] + return entry if isinstance(node, _expr.Function) else entry.body + def infer_shape(inputs): """A method to get the output shape of an intermediate node in the graph.""" - out_type = ir_pass.infer_type(inputs) + out_type = infer_type(inputs) out_shapes = get_const_tuple(out_type.checked_type.shape) return out_shapes @@ -417,7 +425,7 @@ def infer_channels(inputs, transpose=False): """A hack for getting 'channels' or 'units' since caffe2 does not provide these attributes. We check the shape of weights provided to get the number. """ - out_type = ir_pass.infer_type(inputs) + out_type = infer_type(inputs) out_shapes = [get_const_tuple(out_type.checked_type.shape)] channels = out_shapes[0][0] if not transpose else out_shapes[0][1] return channels diff --git a/python/tvm/relay/frontend/coreml.py b/python/tvm/relay/frontend/coreml.py index 1cac547d07c95..e7b129e66724b 100644 --- a/python/tvm/relay/frontend/coreml.py +++ b/python/tvm/relay/frontend/coreml.py @@ -19,7 +19,7 @@ from __future__ import absolute_import as _abs import numpy as np import tvm -from .. import ir_pass +from .. import analysis from .. import expr as _expr from .. import module as _module from .. import op as _op @@ -462,6 +462,6 @@ def from_coreml(model, shape=None): for o in spec.description.output] # for now return first output outexpr = outexpr[0] - func = _expr.Function(ir_pass.free_vars(outexpr), outexpr) + func = _expr.Function(analysis.free_vars(outexpr), outexpr) params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()} return _module.Module.from_expr(func), params diff --git a/python/tvm/relay/frontend/darknet.py b/python/tvm/relay/frontend/darknet.py index 7b26ed5692df7..f452146ae46cc 100644 --- a/python/tvm/relay/frontend/darknet.py +++ b/python/tvm/relay/frontend/darknet.py @@ -23,7 +23,7 @@ from enum import Enum import numpy as np import tvm -from .. import ir_pass +from .. import analysis from .. import expr as _expr from .. 
import module as _module from .common import get_relay_op, new_var @@ -820,7 +820,7 @@ def from_darknet(self): outputs = _as_list(sym) + self._outs outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) - sym = _expr.Function(ir_pass.free_vars(outputs), outputs) + sym = _expr.Function(analysis.free_vars(outputs), outputs) return _module.Module.from_expr(sym), self._tvmparams def from_darknet(net, diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index ad033f9bf3260..91da87c84b809 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -20,7 +20,7 @@ import sys import numpy as np import tvm -from .. import ir_pass +from .. import analysis from .. import expr as _expr from .. import module as _module from .. import op as _op @@ -743,6 +743,6 @@ def _convert_input_layer(keras_layer): outexpr = [etab.get_expr(oc[0].name + ":" + str(oc[1]) + ":" + str(oc[2])) \ for oc in model._output_coordinates] outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr) - func = _expr.Function(ir_pass.free_vars(outexpr), outexpr) + func = _expr.Function(analysis.free_vars(outexpr), outexpr) params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()} return _module.Module.from_expr(func), params diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 0bcee63ad3e8c..26c357e9c9244 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -20,7 +20,7 @@ import json import tvm -from .. import ir_pass +from .. import analysis, transform from .. import expr as _expr from .. import op as _op from .. import module as _module @@ -41,6 +41,13 @@ "relu" : _op.nn.relu } +def _infer_type(node): + """A method to infer the type of an intermediate node in the relay graph.""" + mod = _module.Module.from_expr(node) + mod = transform.InferType()(mod) + entry = mod[mod.entry_func] + return entry if isinstance(node, _expr.Function) else entry.body + def _mx_fully_connected(inputs, attrs): import mxnet as mx units = attrs.get_int("num_hidden") @@ -89,7 +96,8 @@ def _stable_softrelu(x): def _mx_compare(new_op, wrapper): def impl(inputs, attrs): - dtype = ir_pass.infer_type(inputs[0]).checked_type.dtype + expr = _infer_type(inputs[0]) + dtype = expr.checked_type.dtype return wrapper(new_op)(inputs, attrs).astype(dtype) return impl @@ -258,7 +266,8 @@ def _mx_slice_like(inputs, attrs): def _mx_slice_axis(inputs, attrs): assert len(inputs) == 1 - shape = ir_pass.infer_type(inputs[0]).checked_type.shape + expr = _infer_type(inputs[0]) + shape = expr.checked_type.shape axis = attrs.get_int("axis") ax_beg = attrs.get_int("begin") ax_end = attrs.get_str("end") @@ -302,7 +311,8 @@ def _mx_crop_like(inputs, attrs): if offset == (0, 0): new_attrs["axes"] = (2, 3) return _op.slice_like(*inputs, **new_attrs) - like_shape = ir_pass.infer_type(inputs[1]).checked_type.shape + expr = _infer_type(inputs[1]) + like_shape = expr.checked_type.shape new_attrs['begin'] = [0, 0, offset[0], offset[1]] new_attrs['end'] = [like_shape[0], like_shape[1], offset[0]+like_shape[2], offset[1]+like_shape[3]] @@ -532,7 +542,8 @@ def _mx_resize(inputs, attrs): scale_width = attrs.get_float("scale_width", None) height = attrs.get_int("height", 1) width = attrs.get_int("width", 1) - shape = ir_pass.infer_type(inputs[0]).checked_type.shape + expr = _infer_type(inputs[0]) + shape = expr.checked_type.shape if scale_height is not None: height = (scale_height * shape[2]).astype("int32") if 
scale_width is not None: @@ -639,7 +650,8 @@ def _mx_broadcast_axis(inputs, attrs): assert len(axis) == len(size) if len(axis) == 0: return inputs[0] - src_shape = ir_pass.infer_type(inputs[0])._checked_type_.shape + expr = _infer_type(inputs[0]) + src_shape = expr.checked_type.shape tgt_shape = [] for i, dim in enumerate(src_shape): if i not in axis: @@ -734,7 +746,8 @@ def _rnn_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias, activati return out, [out] def _gru_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): - dtype = ir_pass.infer_type(data).checked_type.dtype + expr = _infer_type(data) + dtype = expr.checked_type.dtype i2h = _op.nn.bias_add(_op.nn.dense(data, i2h_weight), i2h_bias, axis=-1) h2h = _op.nn.bias_add(_op.nn.dense(states[0], h2h_weight), h2h_bias, axis=-1) i2h_r, i2h_z, i2h = _op.split(i2h, indices_or_sections=3, axis=1) @@ -776,7 +789,8 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): seq_data = inputs[0] concat_weight = inputs[1] init_states = inputs[2:] - data_shape = ir_pass.infer_type(seq_data).checked_type.shape + expr = _infer_type(seq_data) + data_shape = expr.checked_type.shape seq_len = int(data_shape[0]) assert len(concat_weight) == num_layers * 4 * direct @@ -1099,7 +1113,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, mod=None): outputs = [node_map[e[0]][e[1]] for e in jgraph["heads"]] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) - func = _expr.Function(ir_pass.free_vars(outputs), outputs) + func = _expr.Function(analysis.free_vars(outputs), outputs) return func diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index bb968ec0bea8a..397ca90de55f2 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -22,7 +22,7 @@ import numpy as np import tvm from ... import nd as _nd -from .. import ir_pass +from .. import analysis from .. import transform as _transform from .. import expr as _expr from .. import module as _module @@ -412,7 +412,7 @@ def _impl_v1(cls, inputs, attr, params): else: data, shape = inputs logging.warning("Constant evaluating Reshape's shape argument, may reduce performance") - shape_params = ir_pass.free_vars(shape) + shape_params = analysis.free_vars(shape) func = _expr.Function(shape_params, shape) mod = _module.Module.from_expr(func) seq = _transform.Sequential([_transform.InferType(), @@ -1106,7 +1106,7 @@ def from_onnx(self, graph, opset): # now return the outputs outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) - func = _expr.Function(ir_pass.free_vars(outputs), outputs) + func = _expr.Function(analysis.free_vars(outputs), outputs) return _module.Module.from_expr(func), self._params def _parse_value_proto(self, value_proto): diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index d754e85ef78d7..e14566f6ab334 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -27,7 +27,8 @@ import tvm from topi.util import get_const_tuple -from .. import ir_pass +from .. import analysis +from .. import transform as _transform from .. import expr as _expr from .. import op as _op from ..expr_functor import ExprMutator @@ -38,9 +39,9 @@ def _infer_value(input_val, params): from tvm.contrib import graph_runtime # Check that all free variables have associated parameters. 
- assert all(var.name_hint in params.keys() for var in ir_pass.free_vars( + assert all(var.name_hint in params.keys() for var in analysis.free_vars( input_val)), "All inputs to infer must be available in params." - func = _expr.Function(ir_pass.free_vars(input_val), input_val) + func = _expr.Function(analysis.free_vars(input_val), input_val) with tvm.relay.build_config(opt_level=0): graph, lib, params = tvm.relay.build(func, target="llvm", params=params) ctx = tvm.context("llvm", 0) @@ -235,9 +236,16 @@ def _infer_out_shapes(inputs, params): """A method to get the output shape of intermediate nodes in the relay graph.""" return [_infer_shape(inputs, params)] +def _infer_type(node): + """A method to infer the type of an intermediate node in the relay graph.""" + mod = _module.Module.from_expr(node) + mod = _transform.InferType()(mod) + entry = mod[mod.entry_func] + return entry if isinstance(node, _expr.Function) else entry.body + def _infer_shape(node, params=None): """A method to get the output shape of an intermediate node in the relay graph.""" - out_type = ir_pass.infer_type(node) + out_type = _infer_type(node) return get_const_tuple(out_type.checked_type.shape) def _get_param(params, input_node): @@ -1841,7 +1849,8 @@ def _while_loop(self): bind_map = {} for i, var in enumerate(self.loop_vars): if not isinstance(var, _expr.Var): - var_type = ir_pass.infer_type(var).checked_type + var_chk = _infer_type(var) + var_type = var_chk.checked_type else: var_type = var.type_annotation @@ -2112,7 +2121,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): out.append(out_rnn) out = out[0] if len(out) == 1 else _expr.Tuple(out) - func = _expr.Function(ir_pass.free_vars(out), out) + func = _expr.Function(analysis.free_vars(out), out) self._mod[self._mod.entry_func] = func return self._mod, self._params @@ -2329,7 +2338,8 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_ else: if node_name_prefix not in self._branches: self._branches[node_name_prefix] = Branch() - self._branches[node_name_prefix].cond = ir_pass.infer_type(op[0]) + chk_op = _infer_type(op[0]) + self._branches[node_name_prefix].cond = chk_op elif node.op == "NextIteration": op = self._nodes[node.input[0]] assert len(op) == 1 diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index fe163871fa60f..bf1938b1481e7 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -20,7 +20,7 @@ import math import numpy as np import tvm -from .. import ir_pass +from .. import analysis from .. import expr as _expr from .. import module as _module from .. import op as _op @@ -914,5 +914,5 @@ def from_tflite(model, shape_dict, dtype_dict): params = {k:_nd.array(np.array(v)) for k, v in exp_tab.params.items()} outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) - func = _expr.Function(ir_pass.free_vars(outputs), outputs) + func = _expr.Function(analysis.free_vars(outputs), outputs) return _module.Module.from_expr(func), params diff --git a/python/tvm/relay/module.py b/python/tvm/relay/module.py index 1a5e82269a963..097dbbb8ecaf6 100644 --- a/python/tvm/relay/module.py +++ b/python/tvm/relay/module.py @@ -79,15 +79,6 @@ def _add(self, var, val, update=False): if isinstance(val, _expr.Expr): if isinstance(var, _base.string_types): var = _expr.GlobalVar(var) - - # TODO(@jroesch): Port this logic to C++. 
- if not isinstance(val, _expr.Function): - if isinstance(val, _expr.GlobalVar): - val = ir_pass.eta_expand(val, self) - else: - val = _expr.Function([], val) - - _make.Module_Add(self, var, val, update) else: assert isinstance(val, _ty.Type) diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index fa70e19544677..b7994217e9640 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -22,7 +22,7 @@ from . import _quantize from .. import expr as _expr from .. import module as _module -from .. import ir_pass as _ir_pass +from .. import analysis as _analysis from .. import transform as _transform from .. import op as _op from ... import make as _make @@ -250,7 +250,7 @@ def _make_const(val): const_params[nclip_min] = _make_const(- (valid_range - 1)) const_params[nclip_max] = _make_const((valid_range - 1)) - _ir_pass.post_order_visit(graph, visit_func) + _analysis.post_order_visit(graph, visit_func) return _expr.bind(graph, const_params) diff --git a/python/tvm/relay/testing/dcgan.py b/python/tvm/relay/testing/dcgan.py index 4ee0bd13a5a7e..e9a914ecd69ae 100644 --- a/python/tvm/relay/testing/dcgan.py +++ b/python/tvm/relay/testing/dcgan.py @@ -81,7 +81,7 @@ def get_net(batch_size, random_len=100, oshape=(3, 64, 64), ngf=128, code=None, dc32, ishape=(ngf, 32, 32), oshape=oshape[-3:], kshape=(4, 4), name="g5_deconv") tanh = relay.tanh(dc64) - args = relay.ir_pass.free_vars(tanh) + args = relay.analysis.free_vars(tanh) return relay.Function(args, tanh) diff --git a/python/tvm/relay/testing/densenet.py b/python/tvm/relay/testing/densenet.py index de3ebe36eb7bd..573a4bc367946 100644 --- a/python/tvm/relay/testing/densenet.py +++ b/python/tvm/relay/testing/densenet.py @@ -79,7 +79,7 @@ def _make_dense_net(num_init_features, growth_rate, block_config, ret = layers.dense_add_bias(flat, units=classes, name='dense') - return relay.Function(relay.ir_pass.free_vars(ret), ret) + return relay.Function(relay.analysis.free_vars(ret), ret) def get_workload(densenet_size=121, classes=1000, batch_size=4, image_shape=(3, 224, 224), dtype='float32'): diff --git a/python/tvm/relay/testing/dqn.py b/python/tvm/relay/testing/dqn.py index 034ac0a6c2e5f..fdf46fbc2f7c2 100644 --- a/python/tvm/relay/testing/dqn.py +++ b/python/tvm/relay/testing/dqn.py @@ -54,7 +54,7 @@ def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32" relu4 = relay.nn.relu(dense1) dense2 = layers.dense_add_bias(relu4, units=num_actions, name="dense2") - args = relay.ir_pass.free_vars(dense2) + args = relay.analysis.free_vars(dense2) return relay.Function(args, dense2) diff --git a/python/tvm/relay/testing/inception_v3.py b/python/tvm/relay/testing/inception_v3.py index c9ec3293ed0a1..c3f0181f2951f 100644 --- a/python/tvm/relay/testing/inception_v3.py +++ b/python/tvm/relay/testing/inception_v3.py @@ -266,7 +266,7 @@ def get_net(batch_size, fc1 = relay.nn.dense(flatten, relay.var("fc1_weight"), units=num_classes) fc1 = relay.nn.bias_add(fc1, relay.var("fc2_bias"), axis=-1) inception_v3 = relay.nn.softmax(data=fc1) - args = relay.ir_pass.free_vars(inception_v3) + args = relay.analysis.free_vars(inception_v3) return relay.Function(args, inception_v3) def get_workload(batch_size=1, num_classes=1000, diff --git a/python/tvm/relay/testing/init.py b/python/tvm/relay/testing/init.py index b246b46172766..20b5156990a7c 100644 --- a/python/tvm/relay/testing/init.py +++ b/python/tvm/relay/testing/init.py @@ -150,10 +150,11 @@ def create_workload(net, 
initializer=None, seed=0): params : dict of str to NDArray The parameters. """ - net = relay.ir_pass.infer_type(net) + mod = relay.Module.from_expr(net) + mod = relay.transform.InferType()(mod) + net = mod[mod.entry_func] shape_dict = { v.name_hint : v.checked_type for v in net.params} - net.astext() np.random.seed(seed) initializer = initializer if initializer else Xavier() params = {} diff --git a/python/tvm/relay/testing/lstm.py b/python/tvm/relay/testing/lstm.py index b0915e033ccbf..9721c26f2a151 100644 --- a/python/tvm/relay/testing/lstm.py +++ b/python/tvm/relay/testing/lstm.py @@ -154,7 +154,7 @@ def get_net(iterations, num_hidden, batch_size=1, dtype="float32"): builder.ret(out) body = builder.get() - args = relay.ir_pass.free_vars(body) + args = relay.analysis.free_vars(body) return relay.Function(args, body, input_type) diff --git a/python/tvm/relay/testing/mlp.py b/python/tvm/relay/testing/mlp.py index 562ef21ba9f1c..e178408a6a1bc 100644 --- a/python/tvm/relay/testing/mlp.py +++ b/python/tvm/relay/testing/mlp.py @@ -58,7 +58,7 @@ def get_net(batch_size, fc3 = relay.nn.dense(act2, relay.var("fc3_weight"), units=num_classes) fc3 = relay.nn.bias_add(fc3, relay.var("fc3_bias"), axis=-1) mlp = relay.nn.softmax(data=fc3) - args = relay.ir_pass.free_vars(mlp) + args = relay.analysis.free_vars(mlp) return relay.Function(args, mlp) diff --git a/python/tvm/relay/testing/mobilenet.py b/python/tvm/relay/testing/mobilenet.py index 78e1d82456c84..dff103150ab0a 100644 --- a/python/tvm/relay/testing/mobilenet.py +++ b/python/tvm/relay/testing/mobilenet.py @@ -108,7 +108,7 @@ def mobile_net(num_classes=1000, data_shape=(1, 3, 224, 224), weight = relay.var('fc_weight') fc = relay.nn.dense(data=flatten, weight=weight, units=num_classes) softmax = relay.nn.softmax(data=fc) - return relay.Function(relay.ir_pass.free_vars(softmax), softmax) + return relay.Function(relay.analysis.free_vars(softmax), softmax) def get_workload(batch_size=1, num_classes=1000, image_shape=(3, 224, 224), dtype='float32'): diff --git a/python/tvm/relay/testing/resnet.py b/python/tvm/relay/testing/resnet.py index 9ba57ae09ef5b..f67785917384c 100644 --- a/python/tvm/relay/testing/resnet.py +++ b/python/tvm/relay/testing/resnet.py @@ -169,7 +169,7 @@ def resnet(units, flat = relay.nn.batch_flatten(data=pool1) fc1 = layers.dense_add_bias(data=flat, units=num_classes, name='fc1') net = relay.nn.softmax(data=fc1) - return relay.Function(relay.ir_pass.free_vars(net), net) + return relay.Function(relay.analysis.free_vars(net), net) def get_net(batch_size, diff --git a/python/tvm/relay/testing/squeezenet.py b/python/tvm/relay/testing/squeezenet.py index c7b8e8db166b6..5c90265183ff4 100644 --- a/python/tvm/relay/testing/squeezenet.py +++ b/python/tvm/relay/testing/squeezenet.py @@ -119,7 +119,7 @@ def get_net(batch_size, image_shape, num_classes, version, dtype): net = relay.nn.global_avg_pool2d(net) net = relay.nn.batch_flatten(net) net = relay.nn.softmax(net) - args = relay.ir_pass.free_vars(net) + args = relay.analysis.free_vars(net) return relay.Function(args, net) diff --git a/python/tvm/relay/testing/vgg.py b/python/tvm/relay/testing/vgg.py index bec141f70ffd0..06d9aa3d2d93c 100644 --- a/python/tvm/relay/testing/vgg.py +++ b/python/tvm/relay/testing/vgg.py @@ -90,7 +90,7 @@ def get_net(batch_size, image_shape, num_classes, dtype, num_layers=11, batch_no feature = get_feature(data, layers, filters, batch_norm) classifier = get_classifier(feature, num_classes) symbol = relay.nn.softmax(data=classifier) - args = 
relay.ir_pass.free_vars(symbol) + args = relay.analysis.free_vars(symbol) return relay.Function(args, symbol) diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index ba4857dc4d36e..255718c627f0e 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -277,6 +277,40 @@ def FoldScaleAxis(): return _transform.FoldScaleAxis() +def BackwardFoldScaleAxis(): + """Backward fold axis scaling into weights of conv2d/dense. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass to backward fold expressions. + + Note + ---- + It is recommended to call backward_fold_scale_axis + before using forward_fold_scale_axis. + As backward folding targets common conv-bn pattern. + """ + return _transform.BackwardFoldScaleAxis() + + +def ForwardFoldScaleAxis(): + """Fold the scaling of axis into weights of conv2d/dense. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass to forward fold expressions. + + Note + ---- + It is recommended to call backward_fold_scale_axis + before using forward_fold_scale_axis. + As backward folding targets common conv-bn pattern. + """ + return _transform.ForwardFoldScaleAxis() + + def SimplifyInference(): """Simplify the data-flow graph for inference phase. An simplified expression which is semantically equal to the input expression will be returned. @@ -406,7 +440,7 @@ def ToANormalForm(): Returns ------- - ret: tvm.relay.Pass + ret: Union[tvm.relay.Pass, tvm.relay.Expr] The registered pass that transforms an expression into A Normal Form. """ return _transform.ToANormalForm() @@ -454,6 +488,21 @@ def EliminateCommonSubexpr(fskip=None): def PartialEvaluate(): """Evaluate the static fragment of the code. + Note + ---- + This transformation could be either `Module -> Module` or `Expr -> Expr`. + It will directly transform the input expression to a new one if the target + expression is provided. Otherwise, it will rely on the pass manager to + carry out transformation. + + Parameters + ---------- + expr : Optional[tvm.relay.Expr] + The input expression. + + mod : Optional[tvm.relay.Module] + The global module. + Returns ------- ret: tvm.relay.Pass @@ -461,6 +510,7 @@ def PartialEvaluate(): """ return _transform.PartialEvaluate() + def CanonicalizeCast(): """ Canonicalize cast expressions to make operator fusion more efficient. @@ -473,28 +523,35 @@ def CanonicalizeCast(): return _transform.CanonicalizeCast() -def OptimizeOnExpr(expr, passes): - """Perform optimization passes on an expressioin. +def gradient(expr, mod=None, mode='higher_order'): + """ + Transform the input function, + returning a function that calculate the original result, + paired with gradient of the input. Parameters ---------- - expr: tvm.relay.Expr - The expression for optimization. + expr : tvm.relay.Expr + The input expression, which is a Function or a GlobalVar. - passes: Union[Pass, List[Pass]] - The list of optimizations to be applied. + mod : Optional[tvm.relay.Module] + + mode : Optional[String] + The mode of the automatic differentiation algorithm. + 'first_order' only works on first order code, but will not produce + reference nor closure. + 'higher_order' works on all code using reference and closure. Returns ------- - ret: tvm.relay.Expr - The optimized expression. + expr : tvm.relay.Expr + The transformed expression. 
""" - if isinstance(passes, Pass): - passes = [passes] - if not isinstance(passes, (list, tuple)): - raise TypeError("passes must be a pass or a list of pass objects.") - - return _transform.OptimizeOnExpr(expr, passes) + if mode == 'first_order': + return _transform.first_order_gradient(expr, mod) + if mode == 'higher_order': + return _transform.gradient(expr, mod) + raise Exception('unknown mode') def _wrap_class_module_pass(pass_cls, pass_info): diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 3feb7e4a4b543..3ab57f166d900 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -21,6 +21,7 @@ * \file relay/backend/build_module.cc * \brief Code generation for TVM's graph runtime. */ +#include #include #include #include diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 7ae1befcfe895..83e4a36ff4f93 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index 9b510ad2fd293..9765cf90da18a 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -27,8 +27,9 @@ #define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ #include +#include #include -#include +#include #include #include diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 5c2e5c4c289a1..91a597baceaf3 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -25,7 +25,7 @@ */ #include #include -#include +#include #include "../../common/arena.h" namespace tvm { diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index ff2d9e6117abb..7c97befc55f94 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include "compile_engine.h" diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 65a7efd4c2051..139dab21e973d 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -27,7 +27,6 @@ #include #include -#include #include #include #include diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index 668c024a8d550..6290ef7c6e932 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/src/relay/backend/vm/vm.cc b/src/relay/backend/vm/vm.cc index cf0b952005fcb..4dbcda9abb6f9 100644 --- a/src/relay/backend/vm/vm.cc +++ b/src/relay/backend/vm/vm.cc @@ -28,17 +28,18 @@ #include #include #include -#include +#include namespace tvm { namespace relay { namespace vm { +runtime::vm::VirtualMachine CompileModule(const Module& mod); + using tvm::runtime::Object; using tvm::runtime::ObjectTag; using tvm::runtime::vm::VirtualMachine; - VirtualMachine FromModule(const Module& module, const std::vector& ctxs) { auto vm = CompileModule(module); vm.Init(ctxs); diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 81017d4fddfa6..42e66261a5533 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include "type_functor.h" #include "../../lang/attr_functor.h" diff --git 
a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index e09d790822274..36692c5c571b1 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -345,7 +345,7 @@ void PostOrderVisit(const Expr& e, std::function fvisit) { ExprApplyVisit(fvisit).VisitExpr(e); } -TVM_REGISTER_API("relay._ir_pass.post_order_visit") +TVM_REGISTER_API("relay._analysis.post_order_visit") .set_body_typed([](Expr expr, PackedFunc f) { PostOrderVisit(expr, [f](const Expr& n) { f(n); diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index c57475476e589..6039ba272ddc1 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include #include "type_functor.h" #include "../../lang/attr_functor.h" @@ -412,12 +412,12 @@ size_t StructuralHash::operator()(const Expr& expr) const { return RelayHashHandler().ExprHash(expr); } -TVM_REGISTER_API("relay._ir_pass._expr_hash") +TVM_REGISTER_API("relay._analysis._expr_hash") .set_body_typed([](NodeRef ref) { return static_cast(RelayHashHandler().Hash(ref)); }); -TVM_REGISTER_API("relay._ir_pass._type_hash") +TVM_REGISTER_API("relay._analysis._type_hash") .set_body_typed([](Type type) { return static_cast(RelayHashHandler().TypeHash(type)); }); diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 58f614a3cc77c..51a2aeeeb111f 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -23,7 +23,8 @@ * \brief The global module in Relay. */ #include -#include +#include +#include #include namespace tvm { @@ -184,7 +185,26 @@ TVM_REGISTER_API("relay._make.Module") .set_body_typed(ModuleNode::make); TVM_REGISTER_API("relay._make.Module_Add") -.set_body_method(&ModuleNode::Add); +.set_body([](TVMArgs args, TVMRetValue* ret) { + Module mod = args[0]; + GlobalVar var = args[1]; + NodeRef val = args[2]; + bool update = args[3]; + CHECK(val->derived_from()); + if (val->derived_from()) { + mod->Add(var, Downcast(val), update); + } else if (val->derived_from()) { + GlobalVar gv = Downcast(val); + auto mod_copy = Module(make_node(*mod.operator->())); + mod_copy = transform::EtaExpand()(mod_copy); + auto func = mod_copy->Lookup(gv->name_hint); + mod->Add(var, Downcast(func), update); + } else { + auto func = FunctionNode::make({}, Downcast(val), Type(nullptr), {}); + mod->Add(var, func, update); + } + *ret = mod; +}); TVM_REGISTER_API("relay._module.Module_AddDef") .set_body_method(&ModuleNode::AddDef); @@ -197,39 +217,39 @@ TVM_REGISTER_API("relay._module.Module_GetGlobalTypeVar") TVM_REGISTER_API("relay._module.Module_Lookup") .set_body_typed([](Module mod, GlobalVar var) { - return mod->Lookup(var); - }); + return mod->Lookup(var); +}); TVM_REGISTER_API("relay._module.Module_Lookup_str") .set_body_typed([](Module mod, std::string var) { - return mod->Lookup(var); - }); + return mod->Lookup(var); +}); TVM_REGISTER_API("relay._module.Module_LookupDef") .set_body_typed([](Module mod, GlobalTypeVar var) { - return mod->LookupDef(var); - }); + return mod->LookupDef(var); +}); TVM_REGISTER_API("relay._module.Module_LookupDef_str") .set_body_typed([](Module mod, std::string var) { - return mod->LookupDef(var); - }); + return mod->LookupDef(var); +}); TVM_REGISTER_API("relay._module.Module_FromExpr") .set_body_typed([](Expr e) { - return ModuleNode::FromExpr(e); + return ModuleNode::FromExpr(e); }); TVM_REGISTER_API("relay._module.Module_Update") .set_body_typed([](Module mod, Module from) { - mod->Update(from); - }); + mod->Update(from); +}); 
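The reworked `relay._make.Module_Add` registration above ports the wrapping logic this patch deletes from python/tvm/relay/module.py: a `Function` is stored as-is, a `GlobalVar` is eta-expanded by running the `EtaExpand` pass over a copy of the module and looking the result up by name, and any other expression is wrapped in a zero-argument `Function`. From Python the observable behaviour of `mod[var] = value` is meant to stay the same; a hedged sketch (the `get_global_var` accessor is assumed from the existing Python `Module` API):

    from tvm import relay

    mod = relay.Module()

    # a bare expression is wrapped into a zero-argument Function on the C++ side
    mod["two"] = relay.add(relay.const(1), relay.const(1))

    # an explicit Function is added unchanged
    x = relay.var("x", shape=(10,))
    mod["identity"] = relay.Function([x], x)

    # a GlobalVar referencing an existing function is eta-expanded,
    # i.e. stored as fn (%x) { @identity(%x) }
    mod["identity_again"] = mod.get_global_var("identity")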
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch( - [](const ModuleNode *node, tvm::IRPrinter *p) { - p->stream << "ModuleNode( " << node->functions << ")"; - }); + [](const ModuleNode *node, tvm::IRPrinter *p) { + p->stream << "ModuleNode( " << node->functions << ")"; +}); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index cc71968fba585..82424500ffc8e 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -24,7 +24,8 @@ other expressions. This pass can be used for computing convolution in custom layouts or other general weight pre-transformation. */ -#include +#include +#include #include #include #include @@ -348,9 +349,6 @@ Expr AlterOpLayout(const Expr& expr) { return ForwardRewrite(expr, AlterOpLayoutRewrite, fcontext); } -TVM_REGISTER_API("relay._ir_pass.AlterOpLayout") -.set_body_typed(AlterOpLayout); - } // namespace alter_op_layout namespace transform { diff --git a/src/relay/pass/canonicalize_cast.cc b/src/relay/pass/canonicalize_cast.cc index 99f4a7f44e7e7..04fec248f81c9 100644 --- a/src/relay/pass/canonicalize_cast.cc +++ b/src/relay/pass/canonicalize_cast.cc @@ -22,7 +22,7 @@ * \file canonicalize_cast.cc * \brief Canonicalize cast expressions to make operator fusion more efficient. */ -#include +#include #include #include #include diff --git a/src/relay/pass/canonicalize_ops.cc b/src/relay/pass/canonicalize_ops.cc index ff9e2304a3bc3..fc0c43d200e5d 100644 --- a/src/relay/pass/canonicalize_ops.cc +++ b/src/relay/pass/canonicalize_ops.cc @@ -23,7 +23,7 @@ * \brief Canonicalize special operators to basic operators. This can simplify latter analysis. (e.g. Expand bias_add to expand_dims and broadcast_add.) */ -#include +#include #include #include #include @@ -61,9 +61,6 @@ Expr CanonicalizeOps(const Expr& e) { return BiasAddSimplifier().Mutate(e); } -TVM_REGISTER_API("relay._ir_pass.canonicalize_ops") -.set_body_typed(CanonicalizeOps); - namespace transform { Pass CanonicalizeOps() { diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index c95c1ddf8e160..d72705c8ce470 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -33,7 +33,7 @@ * convolution branches, such as Inception block. */ -#include +#include #include #include #include @@ -355,9 +355,6 @@ Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) { return ParallelConv2DCombiner(min_num_branches).Combine(expr); } -TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D") -.set_body_typed(CombineParallelConv2D); - namespace transform { Pass CombineParallelConv2D(uint64_t min_num_branches) { diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc index 8799bf403375e..54075f0699e6f 100644 --- a/src/relay/pass/dead_code.cc +++ b/src/relay/pass/dead_code.cc @@ -28,8 +28,9 @@ * CalcDep turn an expr into a dependency graph of expr, * GenLet turn the dependency graph into a let list, taking only the used value. 
*/ -#include +#include #include +#include #include "let_list.h" namespace tvm { diff --git a/src/relay/pass/device_annotation.cc b/src/relay/pass/device_annotation.cc index 8eeb493f1feba..aec974b184d3f 100644 --- a/src/relay/pass/device_annotation.cc +++ b/src/relay/pass/device_annotation.cc @@ -34,7 +34,6 @@ #include #include #include -#include #include #include @@ -559,13 +558,13 @@ Map CollectDeviceAnnotationOps(const Expr& expr) { return AnnotatationVisitor::GetAnnotations(expr); } -TVM_REGISTER_API("relay._ir_pass.CollectDeviceInfo") +TVM_REGISTER_API("relay._analysis.CollectDeviceInfo") .set_body_typed(CollectDeviceInfo); -TVM_REGISTER_API("relay._ir_pass.RewriteDeviceAnnotation") +TVM_REGISTER_API("relay._analysis.RewriteDeviceAnnotation") .set_body_typed(RewriteAnnotatedOps); -TVM_REGISTER_API("relay._ir_pass.CollectDeviceAnnotationOps") +TVM_REGISTER_API("relay._analysis.CollectDeviceAnnotationOps") .set_body_typed(CollectDeviceAnnotationOps); namespace transform { diff --git a/src/relay/pass/eliminate_common_subexpr.cc b/src/relay/pass/eliminate_common_subexpr.cc index 883681adcaf45..33a791b2bd996 100644 --- a/src/relay/pass/eliminate_common_subexpr.cc +++ b/src/relay/pass/eliminate_common_subexpr.cc @@ -27,7 +27,7 @@ * to replace an expression with a previously appeared expression with the same input and * attributes. The fskip callback argument allows us to skip specific expressions. */ -#include +#include #include #include #include @@ -85,9 +85,6 @@ Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc callback) { return CommonSubexprEliminator(callback)(expr); } -TVM_REGISTER_API("relay._ir_pass.eliminate_common_subexpr") -.set_body_typed(EliminateCommonSubexpr); - namespace transform { Pass EliminateCommonSubexpr(PackedFunc fskip) { diff --git a/src/relay/pass/eta_expand.cc b/src/relay/pass/eta_expand.cc index 3139d41d63937..e73e3778395e9 100644 --- a/src/relay/pass/eta_expand.cc +++ b/src/relay/pass/eta_expand.cc @@ -25,7 +25,8 @@ * \brief Add abstraction over a function. For example, abs will become (fun x -> abs x). 
* */ -#include +#include +#include namespace tvm { namespace relay { @@ -44,10 +45,8 @@ Expr EtaExpand(const Expr& e, const Module& mod) { original_type_params = func->type_params; ret_type = func->ret_type; } else { - auto inferred = InferType(e, mod); - CHECK(inferred->is_type()); - - auto func = GetRef(inferred.as_derived()); + CHECK(e->is_type()); + auto func = GetRef(e.as_derived()); original_params = func->params; original_type_params = func->type_params; ret_type = func->ret_type; @@ -62,19 +61,18 @@ Expr EtaExpand(const Expr& e, const Module& mod) { auto new_func = FunctionNode::make(args, CallNode::make(e, params), ret_type, original_type_params); - return InferType(new_func, mod); + return new_func; } -TVM_REGISTER_API("relay._ir_pass.eta_expand").set_body_typed(EtaExpand); - namespace transform { Pass EtaExpand() { runtime::TypedPackedFunc pass_func = [=](Function f, Module m, PassContext pc) { - return Downcast(EtaExpand(f, m)); - }; - return CreateFunctionPass(pass_func, 1, "EtaExpand", {}); + return Downcast(EtaExpand(f, m)); + }; + Pass expanded = CreateFunctionPass(pass_func, 1, "EtaExpand", {}); + return Sequential({expanded, InferType()}); } TVM_REGISTER_API("relay._transform.EtaExpand") diff --git a/src/relay/pass/feature.cc b/src/relay/pass/feature.cc index e86ca06211126..df3a5d7ecec52 100644 --- a/src/relay/pass/feature.cc +++ b/src/relay/pass/feature.cc @@ -23,7 +23,7 @@ * \brief Detect features used in Expr/Module */ #include -#include +#include #include #include #include @@ -97,7 +97,7 @@ Array PyDetectFeature(const Expr& expr, const Module& mod) { return static_cast>(fs); } -TVM_REGISTER_API("relay._ir_pass.detect_feature") +TVM_REGISTER_API("relay._analysis.detect_feature") .set_body_typed(PyDetectFeature); } // namespace relay diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index 815407038b082..71d189b0800fc 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -21,7 +21,7 @@ * Copyright (c) 2018 by Contributors * \file constant_folding.cc */ -#include +#include #include #include #include @@ -156,9 +156,13 @@ class ConstantFolder : public ExprMutator { } // Constant evaluate a expression. Expr ConstEvaluate(Expr expr) { - expr = InferType(expr, Module(nullptr)); - expr = FuseOps(expr, 0, Module(nullptr)); - expr = InferType(expr, Module(nullptr)); + std::vector passes = {transform::FuseOps(0), + transform::InferType()}; + auto mod = ModuleNode::FromExpr(expr); + auto seq = transform::Sequential(passes); + mod = seq(mod); + auto entry_func = mod->Lookup(mod->entry_func); + expr = expr.as() == nullptr ? entry_func->body : entry_func; return ValueToExpr(executor_(expr)); } // Evaluate shape_of op @@ -213,9 +217,6 @@ Expr FoldConstant(const Expr& expr) { Module(nullptr), ctx, target)).Mutate(expr); } -TVM_REGISTER_API("relay._ir_pass.FoldConstant") -.set_body_typed(FoldConstant); - namespace transform { Pass FoldConstant() { diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc index 53089807ace5f..868a08f8b5769 100644 --- a/src/relay/pass/fold_scale_axis.cc +++ b/src/relay/pass/fold_scale_axis.cc @@ -26,7 +26,7 @@ * conv/dense operators. 
*/ #include -#include +#include #include #include #include @@ -545,10 +545,6 @@ Expr ForwardFoldScaleAxis(const Expr& data) { data, "FScaleAxisForwardRewrite", fcontext); } -// Expose the FoldScaleAxisFoward -TVM_REGISTER_API("relay._ir_pass.forward_fold_scale_axis") -.set_body_typed(ForwardFoldScaleAxis); - //---------------------------------------- // Implement backward transformations. //---------------------------------------- @@ -947,9 +943,6 @@ Expr BackwardFoldScaleAxis(const Expr& data) { return make_node()->Fold(data); } -TVM_REGISTER_API("relay._ir_pass.backward_fold_scale_axis") -.set_body_typed(BackwardFoldScaleAxis); - } // namespace fold_scale_axis namespace transform { @@ -964,6 +957,9 @@ Pass ForwardFoldScaleAxis() { {ir::StringImm::make("InferType")}); } +TVM_REGISTER_API("relay._transform.ForwardFoldScaleAxis") +.set_body_typed(ForwardFoldScaleAxis); + Pass BackwardFoldScaleAxis() { runtime::TypedPackedFunc pass_func = [=](Function f, Module m, PassContext pc) { @@ -974,6 +970,9 @@ Pass BackwardFoldScaleAxis() { {ir::StringImm::make("InferType")}); } +TVM_REGISTER_API("relay._transform.BackwardFoldScaleAxis") +.set_body_typed(BackwardFoldScaleAxis); + Pass FoldScaleAxis() { // FoldScaleAxis pass contains the following three passes. Therefore, we can // register it as a sequential pass. diff --git a/src/relay/pass/forward_rewrite.cc b/src/relay/pass/forward_rewrite.cc index 8ad61270e33a8..6c66d6e982a71 100644 --- a/src/relay/pass/forward_rewrite.cc +++ b/src/relay/pass/forward_rewrite.cc @@ -23,9 +23,9 @@ * \file forward_rewrite.cc * \brief Apply rewriting rules in a forward fashion. */ -#include #include #include +#include #include "pass_util.h" namespace tvm { @@ -206,37 +206,5 @@ Expr ForwardRewrite(const Expr& expr, return ForwardRewriter(&rewrite_func, fcontext, fmulti_ref_trigger).Rewrite(expr); } -namespace transform { - -using std::function; - -Pass ForwardRewrite(const std::string& rewrite_map_attr_name, - function fcontext, - function fmulti_ref_trigger) { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { - return Downcast(ForwardRewrite(f, - rewrite_map_attr_name, - fcontext, - fmulti_ref_trigger)); - }; - return CreateFunctionPass(pass_func, 1, "ForwardRewrite", {}); -} - -Pass ForwardRewrite(const FForwardRewrite& rewrite_func, - function fcontext, - function fmulti_ref_trigger) { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { - return Downcast(ForwardRewrite(f, - rewrite_func, - fcontext, - fmulti_ref_trigger)); - }; - return CreateFunctionPass(pass_func, 1, "ForwardRewriteFunc", {}); -} - -} // namespace transform - } // namespace relay } // namespace tvm diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 9f940e54953b9..cdd2837463659 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -26,7 +26,7 @@ * Fuse necessary ops into a single one. 
*/ #include -#include +#include #include #include #include @@ -963,9 +963,6 @@ Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& module) { } } -TVM_REGISTER_API("relay._ir_pass.FuseOps") -.set_body_typed(FuseOps); - namespace transform { Pass FuseOps(int fuse_opt_level) { diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc index 5d26f7adcff77..1abe7a94b621f 100644 --- a/src/relay/pass/gradient.cc +++ b/src/relay/pass/gradient.cc @@ -26,7 +26,8 @@ #include #include #include -#include +#include +#include #include "pattern_util.h" #include "let_list.h" #include "../ir/type_functor.h" @@ -246,7 +247,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) { return FunctionNode::make(f->params, body, GradRetType(GetRef(f)), {}); } -TVM_REGISTER_API("relay._ir_pass.first_order_gradient") +TVM_REGISTER_API("relay._analysis.first_order_gradient") .set_body_typed(FirstOrderGradient); struct ReverseADType : TypeMutator { @@ -351,7 +352,7 @@ Expr Gradient(const Expr& re, const Module& mod) { return FunctionNode::make(f->params, body, GradRetType(GetRef(f)), {}); } -TVM_REGISTER_API("relay._ir_pass.gradient") +TVM_REGISTER_API("relay._transform.gradient") .set_body_typed(Gradient); } // namespace relay diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index 976a2ef8ec54d..c0f4a7c5967d1 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -32,7 +32,7 @@ * We check this by ensuring the `dtype` field of a Tensor always * contains a data type such as `int`, `float`, `uint`. */ -#include +#include #include #include "../ir/type_functor.h" @@ -183,7 +183,7 @@ Kind KindCheck(const Type& t, const Module& mod) { return kc.Check(t); } -TVM_REGISTER_API("relay._ir_pass.check_kind") +TVM_REGISTER_API("relay._analysis.check_kind") .set_body([](TVMArgs args, TVMRetValue* ret) { if (args.size() == 1) { *ret = KindCheck(args[0], ModuleNode::make({}, {})); diff --git a/src/relay/pass/mac_count.cc b/src/relay/pass/mac_count.cc index ce70eb0512149..48a0dfb847466 100644 --- a/src/relay/pass/mac_count.cc +++ b/src/relay/pass/mac_count.cc @@ -30,7 +30,7 @@ #include #include #include -#include +#include #include #include "pattern_util.h" @@ -188,7 +188,7 @@ int64_t GetTotalMacNumber(const Expr& expr) { return MacCounter::GetTotalMacNumber(expr); } -TVM_REGISTER_API("relay._ir_pass.GetTotalMacNumber") +TVM_REGISTER_API("relay._analysis.GetTotalMacNumber") .set_body_typed(GetTotalMacNumber); } // namespace mac_count diff --git a/src/relay/pass/match_exhaustion.cc b/src/relay/pass/match_exhaustion.cc index 173d6eacf528f..cc00a54cde0ab 100644 --- a/src/relay/pass/match_exhaustion.cc +++ b/src/relay/pass/match_exhaustion.cc @@ -32,7 +32,6 @@ #include #include #include -#include #include namespace tvm { @@ -236,15 +235,15 @@ Array UnmatchedCases(const Match& match, const Module& mod) { } // expose for testing only -TVM_REGISTER_API("relay._ir_pass.unmatched_cases") -.set_body_typed(const Match&, - const Module&)>([](const Match& match, - const Module& mod_ref) { - Module call_mod = mod_ref; - if (!call_mod.defined()) { - call_mod = ModuleNode::make({}, {}); - } - return UnmatchedCases(match, call_mod); - }); +TVM_REGISTER_API("relay._analysis.unmatched_cases") +.set_body_typed(const Match&, const Module&)>( + [](const Match& match, const Module& mod_ref) { + Module call_mod = mod_ref; + if (!call_mod.defined()) { + call_mod = ModuleNode::make({}, {}); + } + return UnmatchedCases(match, call_mod); + }); + } // namespace relay } // namespace tvm 
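The hunks above and below replace direct Expr -> Expr helpers (InferType, FuseOps, FoldConstant, PartialEval, ...) with Module -> Module passes driven by the pass infra. For illustration only (not part of the patch), here is a minimal Python sketch of the pattern the rewritten C++ ConstEvaluate code and the updated tests rely on; the helper name run_opt_passes is invented here, but every call inside it appears verbatim in the test changes later in this patch:

    from tvm import relay
    from tvm.relay import transform

    def run_opt_passes(expr):
        """Illustrative helper: optimize a bare expression via the pass infra."""
        # Passes only operate on modules now, so wrap the expression first.
        mod = relay.Module.from_expr(expr)
        # Each pass maps Module -> Module; apply them one after another.
        mod = transform.InferType()(mod)
        mod = transform.FuseOps(0)(mod)
        mod = transform.InferType()(mod)
        # Read the result back out of the module's entry function: the whole
        # function if the input was a Function, otherwise just its body.
        entry = mod[mod.entry_func]
        return entry if isinstance(expr, relay.Function) else entry.body

This mirrors the run_infer_type helpers added to the Python tests and the module-based ConstEvaluate rewrites in fold_constant.cc and partial_eval.cc.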
diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index e7edbb3153d85..acc60982cff44 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -91,7 +91,8 @@ * * These assumptions do not affect the correctness of the algorithm, however. */ -#include +#include +#include #include #include #include @@ -740,9 +741,14 @@ class PartialEvaluator : public ExprFunctor // Constant evaluate a expression. PStatic ConstEvaluate(const Expr& expr, LetList* ll) { - Expr infered = InferType(expr, Module(nullptr)); - Expr fused = FuseOps(infered, 0, Module(nullptr)); - Expr fused_infered = InferType(fused, Module(nullptr)); + std::vector passes = {transform::FuseOps(0), + transform::InferType()}; + auto mod = ModuleNode::FromExpr(expr); + auto seq = transform::Sequential(passes); + mod = seq(mod); + auto entry_func = mod->Lookup(mod->entry_func); + auto fused_infered = + expr.as() == nullptr ? entry_func->body : entry_func; return Reify(executor_(fused_infered), ll); } diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index a620316035c7e..d63d9121fe27e 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -573,18 +573,6 @@ class PassContext::Internal { } }; -Expr OptimizeOnExpr(const Expr& expr, const Array& passes) { - auto mod = ModuleNode::FromExpr(expr); - Sequential seq(passes); - auto pass_ctx = PassContext::Create(); - pass_ctx->opt_level = 3; - tvm::With ctx_scope(pass_ctx); - mod = seq(mod); - CHECK(mod.defined()); - auto entry_func = mod->Lookup(mod->entry_func); - return expr.as() == nullptr ? entry_func->body : entry_func; -} - TVM_REGISTER_API("relay._transform.GetCurrentPassContext") .set_body_typed(PassContext::Current); @@ -594,9 +582,6 @@ TVM_REGISTER_API("relay._transform.EnterPassContext") TVM_REGISTER_API("relay._transform.ExitPassContext") .set_body_typed(PassContext::Internal::ExitScope); -TVM_REGISTER_API("relay._transform.OptimizeOnExpr") -.set_body_typed(OptimizeOnExpr); - } // namespace transform } // namespace relay } // namespace tvm diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc index 1503d67feaf10..7527d2a216286 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize.cc @@ -27,9 +27,10 @@ */ #include #include -#include +#include #include #include +#include #include #include #include @@ -259,6 +260,13 @@ Expr QuantizeRealize(const Call& ref_call, return QRealizeIntExprNode::make(round_data, dom_scale, Float(32)); } +Expr FoldConstantOpt(const Expr& expr) { + auto mod = ModuleNode::FromExpr(expr); + mod = transform::FoldConstant()(mod); + auto entry_func = mod->Lookup(mod->entry_func); + return expr.as() == nullptr ? 
entry_func->body : entry_func; +} + RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize") .set_attr("FQRealizeRewrite", QuantizeRealize); @@ -290,7 +298,8 @@ Expr Conv2dRealize(const Call& ref_call, Expr ret = CallNode::make(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args); - Expr dom_scale = FoldConstant(Multiply(lhs->dom_scale, rhs->dom_scale)); + Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); + Expr dom_scale = FoldConstantOpt(mul); return QRealizeIntExprNode::make(ret, dom_scale, out_dtype); } @@ -323,7 +332,8 @@ Expr DenseRealize(const Call& ref_call, Expr ret = CallNode::make(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args); - Expr dom_scale = FoldConstant(Multiply(lhs->dom_scale, rhs->dom_scale)); + Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); + Expr dom_scale = FoldConstantOpt(mul); return QRealizeIntExprNode::make(ret, dom_scale, out_dtype); } @@ -356,7 +366,8 @@ Expr MulRealize(const Call& ref_call, } Expr ret = ForwardOp(ref_call, {ldata, rdata}); - Expr dom_scale = FoldConstant(Multiply(lhs->dom_scale, rhs->dom_scale)); + Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); + Expr dom_scale = FoldConstantOpt(mul); return QRealizeIntExprNode::make(ret, dom_scale, dtype); } CHECK(!new_args[0]->derived_from() && !new_args[1]->derived_from()); diff --git a/src/relay/pass/simplify_inference.cc b/src/relay/pass/simplify_inference.cc index 6d6b24abec203..daf48c44173e4 100644 --- a/src/relay/pass/simplify_inference.cc +++ b/src/relay/pass/simplify_inference.cc @@ -21,7 +21,7 @@ * Copyright (c) 2018 by Contributors * \file simplify_inference.cc */ -#include +#include #include #include #include @@ -103,9 +103,6 @@ Expr SimplifyInference(const Expr& e) { return InferenceSimplifier().Mutate(e); } -TVM_REGISTER_API("relay._ir_pass.simplify_inference") -.set_body_typed(SimplifyInference); - namespace transform { Pass SimplifyInference() { diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc index b5a3f8552d8da..1b4b642eea8cb 100644 --- a/src/relay/pass/to_a_normal_form.cc +++ b/src/relay/pass/to_a_normal_form.cc @@ -24,7 +24,7 @@ * * \brief Turn implicit sharing into observable sharing. */ -#include +#include #include #include #include diff --git a/src/relay/pass/to_graph_normal_form.cc b/src/relay/pass/to_graph_normal_form.cc index c1ae19e92748e..f6f2a07bc80f4 100644 --- a/src/relay/pass/to_graph_normal_form.cc +++ b/src/relay/pass/to_graph_normal_form.cc @@ -24,6 +24,7 @@ * * \brief Turn A normal form into graph normal form. */ +#include #include #include #include "let_list.h" diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 4b126e5299cfd..ff356cb9c9ef8 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -42,7 +42,7 @@ #include #include #include -#include +#include #include #include "./pass_util.h" #include "type_solver.h" @@ -813,11 +813,6 @@ Function InferType(const Function& func, return Downcast(func_ret); } -TVM_REGISTER_API("relay._ir_pass.infer_type") -.set_body_typed([](const Expr& expr, const Module& mod_ref) { - return InferType(expr, mod_ref); - }); - namespace transform { Pass InferType() { diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 84f72e0d5a008..8289130f53d85 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -512,7 +512,7 @@ bool TypeSolver::Solve() { } // Expose type solver only for debugging purposes. 
-TVM_REGISTER_API("relay._ir_pass._test_type_solver") +TVM_REGISTER_API("relay._analysis._test_type_solver") .set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) { using runtime::PackedFunc; using runtime::TypedPackedFunc; diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index 8b24e8605f5f7..002ccac356f02 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -27,7 +27,7 @@ #include #include -#include +#include #include #include #include diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index 3ec4f75cd1ad8..2497197ffbe56 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -24,7 +24,7 @@ * * \brief Utility functions for Relay. */ -#include +#include #include #include #include "pass_util.h" @@ -274,10 +274,10 @@ tvm::Array AllVars(const Expr& expr) { return VarVisitor().All(expr); } -TVM_REGISTER_API("relay._ir_pass.free_vars") +TVM_REGISTER_API("relay._analysis.free_vars") .set_body_typed(FreeVars); -TVM_REGISTER_API("relay._ir_pass.bound_vars") +TVM_REGISTER_API("relay._analysis.bound_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef x = args[0]; if (x.as_derived()) { @@ -287,10 +287,10 @@ TVM_REGISTER_API("relay._ir_pass.bound_vars") } }); -TVM_REGISTER_API("relay._ir_pass.all_vars") +TVM_REGISTER_API("relay._analysis.all_vars") .set_body_typed(AllVars); -TVM_REGISTER_API("relay._ir_pass.free_type_vars") +TVM_REGISTER_API("relay._analysis.free_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef x = args[0]; Module mod = args[1]; @@ -301,7 +301,7 @@ TVM_REGISTER_API("relay._ir_pass.free_type_vars") } }); -TVM_REGISTER_API("relay._ir_pass.bound_type_vars") +TVM_REGISTER_API("relay._analysis.bound_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef x = args[0]; Module mod = args[1]; @@ -312,7 +312,7 @@ TVM_REGISTER_API("relay._ir_pass.bound_type_vars") } }); -TVM_REGISTER_API("relay._ir_pass.all_type_vars") +TVM_REGISTER_API("relay._analysis.all_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef x = args[0]; Module mod = args[1]; diff --git a/src/relay/pass/well_formed.cc b/src/relay/pass/well_formed.cc index dea9374812896..bfe8865ab52f9 100644 --- a/src/relay/pass/well_formed.cc +++ b/src/relay/pass/well_formed.cc @@ -22,7 +22,7 @@ * \file well_formed.cc * \brief check that expression is well formed. 
*/ -#include +#include #include #include #include @@ -78,7 +78,7 @@ bool WellFormed(const Expr& e) { return WellFormedChecker().CheckWellFormed(e); } -TVM_REGISTER_API("relay._ir_pass.well_formed") +TVM_REGISTER_API("relay._analysis.well_formed") .set_body_typed(WellFormed); } // namespace relay diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc index 3f46eed9f10ec..a8a63dd44ef95 100644 --- a/tests/cpp/relay_build_module_test.cc +++ b/tests/cpp/relay_build_module_test.cc @@ -22,7 +22,8 @@ #include #include #include -#include +#include +#include #include #include #include diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc index ffd0f7c4a26fa..8257e94db1977 100644 --- a/tests/cpp/relay_pass_type_infer_test.cc +++ b/tests/cpp/relay_pass_type_infer_test.cc @@ -21,7 +21,8 @@ #include #include #include -#include +#include +#include TEST(Relay, SelfReference) { using namespace tvm; @@ -32,10 +33,9 @@ TEST(Relay, SelfReference) { auto y = relay::VarNode::make("y", tensor_type); auto call = relay::CallNode::make(f, Array{ y }); auto fx = relay::FunctionNode::make(tvm::Array{ y }, call, relay::Type(), {}); - auto empty_module = - relay::ModuleNode::make(Map{}, - Map{}); - auto type_fx = relay::InferType(fx, empty_module); + auto mod = relay::ModuleNode::FromExpr(fx); + mod = relay::transform::InferType()(mod); + auto type_fx = mod->Lookup(mod->entry_func); auto expected = relay::FuncTypeNode::make(tvm::Array{ tensor_type }, tensor_type, {}, {}); CHECK(AlphaEqual(type_fx->checked_type(), expected)); diff --git a/tests/cpp/relay_transform_sequential.cc b/tests/cpp/relay_transform_sequential.cc index b61a5cc0daade..a943ba29cc923 100644 --- a/tests/cpp/relay_transform_sequential.cc +++ b/tests/cpp/relay_transform_sequential.cc @@ -23,7 +23,7 @@ #include #include #include -#include +#include #include #include #include @@ -100,7 +100,9 @@ TEST(Relay, Sequential) { relay::FunctionNode::make(relay::FreeVars(zz), zz, relay::Type(), {}); // Infer type for the expected function. - auto expected = relay::InferType(expected_func, relay::Module(nullptr)); + auto mod1 = relay::ModuleNode::FromExpr(expected_func); + mod1 = relay::transform::InferType()(mod1); + auto expected = mod1->Lookup(mod1->entry_func); CHECK(relay::AlphaEqual(f, expected)); } diff --git a/tests/python/frontend/caffe2/model_zoo/squeezenet.py b/tests/python/frontend/caffe2/model_zoo/squeezenet.py index 74ade8989d053..3c21138343c61 100644 --- a/tests/python/frontend/caffe2/model_zoo/squeezenet.py +++ b/tests/python/frontend/caffe2/model_zoo/squeezenet.py @@ -95,7 +95,7 @@ def get_net(batch_size, image_shape, num_classes, dtype): net = relay.nn.relu(net) net = relay.nn.global_avg_pool2d(net) net = relay.nn.softmax(net, axis=1) - args = relay.ir_pass.free_vars(net) + args = relay.analysis.free_vars(net) return relay.Function(args, net) diff --git a/tests/python/frontend/caffe2/test_graph.py b/tests/python/frontend/caffe2/test_graph.py index ea3a36e606634..98f872ce19b20 100644 --- a/tests/python/frontend/caffe2/test_graph.py +++ b/tests/python/frontend/caffe2/test_graph.py @@ -16,13 +16,15 @@ # under the License. 
"""Test graph equality of caffe2 models.""" from tvm import relay +from tvm.relay import transform from model_zoo import c2_squeezenet, relay_squeezenet -def compare_graph(f1, f2): - f1 = relay.ir_pass.infer_type(f1) - f2 = relay.ir_pass.infer_type(f2) - assert relay.ir_pass.alpha_equal(f1, f2) +def compare_graph(lhs_mod, func): + rhs_mod = relay.Module.from_expr(func) + rhs_mod = transform.InferType()(rhs_mod) + assert relay.analysis.alpha_equal(lhs_mod[lhs_mod.entry_func], + rhs_mod[rhs_mod.entry_func]) def test_squeeze_net(): @@ -31,7 +33,7 @@ def test_squeeze_net(): mod, _, = relay.frontend.from_caffe2( c2_squeezenet.init_net, c2_squeezenet.predict_net, shape_dict, dtype_dict) relay_func, _ = relay_squeezenet() - compare_graph(mod[mod.entry_func], relay_func) + compare_graph(mod, relay_func) if __name__ == '__main__': diff --git a/tests/python/frontend/mxnet/test_graph.py b/tests/python/frontend/mxnet/test_graph.py index b7d3ba4a5b605..37a46f6ce3dc0 100644 --- a/tests/python/frontend/mxnet/test_graph.py +++ b/tests/python/frontend/mxnet/test_graph.py @@ -16,12 +16,11 @@ # under the License. import mxnet as mx from tvm import relay +from tvm.relay import transform import model_zoo def compare_graph(f1, f2): - f1 = relay.ir_pass.infer_type(f1) - f2 = relay.ir_pass.infer_type(f2) - assert relay.ir_pass.alpha_equal(f1, f2) + assert relay.analysis.alpha_equal(f1, f2) def test_mlp(): shape = {"data": (1, 1, 28, 28)} @@ -97,7 +96,10 @@ def relay_compose(F, **kwargs): y = F.var("y", shape=yshape) z = F.split(x, **kwargs) z = F.subtract(F.add(z[0], z[2]), y) - return relay.Function(relay.ir_pass.free_vars(z), z) + func = relay.Function(relay.analysis.free_vars(z), z) + mod = relay.Module.from_expr(func) + mod = transform.InferType()(mod) + return mod[mod.entry_func] mx_sym = mx_compose(mx, num_outputs=3, axis=1) mod, _ = relay.frontend.from_mxnet( diff --git a/tests/python/frontend/nnvm_to_relay/test_alter_conv2d.py b/tests/python/frontend/nnvm_to_relay/test_alter_conv2d.py index d3538bb0085b6..d59fe1830a184 100644 --- a/tests/python/frontend/nnvm_to_relay/test_alter_conv2d.py +++ b/tests/python/frontend/nnvm_to_relay/test_alter_conv2d.py @@ -20,7 +20,8 @@ from tvm import relay from tvm import autotvm -from tvm.relay.ir_pass import infer_type, alpha_equal +from tvm.relay import transform +from tvm.relay.analysis import alpha_equal def test_alter_layout_conv2d(): @@ -57,12 +58,11 @@ def convnet(): n15 = relay.reshape(n14, newshape=[1, 1, 3, 3, 224, 224]) n16 = relay.transpose(n15, axes=[0, 1, 4, 2, 5, 3]) net = relay.reshape(n16, newshape=[1, 1, 672, 672]) - args = relay.ir_pass.free_vars(net) + args = relay.analysis.free_vars(net) return relay.Function(args, net) # orig net N = convnet() - N = infer_type(N) # trigger a test # for each known alter_conv2d @@ -75,11 +75,12 @@ def convnet(): for tgt in targets: with tvm.target.create(tgt) as target: with autotvm.tophub.context(target): - O = relay.ir_pass.alter_op_layout(N) - O = relay.ir_pass.infer_type(O) + mod = relay.Module.from_expr(N) + mod = transform.AlterOpLayout()(mod) + O = mod[mod.entry_func] # graph should differ - assert not relay.ir_pass.alpha_equal(N, O) + assert not relay.analysis.alpha_equal(N, O) if __name__ == "__main__": np.random.seed(42) diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index f3a08a8698410..390576f87c184 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -14,12 +14,10 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. -import numpy as np import tvm from tvm import relay -from tvm.relay.ir_pass import infer_type -from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue -from tvm.relay import testing, create_executor +from tvm.relay.backend.interpreter import ConstructorValue +from tvm.relay import create_executor from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr @@ -125,8 +123,14 @@ def test_nat_value(): def test_nat_constructor(): - assert relay.ir_pass.infer_type(z(), mod).checked_type == nat() - assert relay.ir_pass.infer_type(s(z()), mod).checked_type == nat() + func = relay.Function([], z()) + test_z = relay.GlobalVar("test_z") + mod[test_z] = func + assert mod[test_z].body.checked_type == nat() + test_sz = relay.GlobalVar("test_sz") + func = relay.Function([], s(z())) + mod[test_sz] = func + assert mod[test_sz].body.checked_type == nat() def test_double(): @@ -142,8 +146,10 @@ def test_add(): def test_list_constructor(): - a = relay.TypeVar("a") - assert relay.ir_pass.infer_type(cons(z(), nil()), mod).checked_type == l(nat()) + test_consz = relay.GlobalVar("test_consz") + func = relay.Function([], cons(z(), nil())) + mod[test_consz] = func + assert mod[test_consz].body.checked_type == l(nat()) def test_hd_tl(): expected = list(range(10)) diff --git a/tests/python/relay/test_backend_compile_engine.py b/tests/python/relay/test_backend_compile_engine.py index f493a9b3f537d..479c4169a9590 100644 --- a/tests/python/relay/test_backend_compile_engine.py +++ b/tests/python/relay/test_backend_compile_engine.py @@ -26,8 +26,10 @@ def get_func(shape): x = relay.var("x", shape=shape) y = relay.add(x, x) z = relay.add(y, x) - f = relay.ir_pass.infer_type(relay.Function([x], z)) - return f + f = relay.Function([x], z) + mod = relay.Module.from_expr(f) + mod = relay.transform.InferType()(mod) + return mod[mod.entry_func] z1 = engine.lower(get_func((10,)), "llvm") z2 = engine.lower(get_func((10,)), "llvm") z3 = engine.lower(get_func(()), "llvm") @@ -55,7 +57,7 @@ def test_compile_placeholder_bypass(): y = relay.var("y", shape=(2, 3)) z = relay.var("z", shape=(2, 3)) result = relay.Tuple([x, relay.op.concatenate([y, z], axis=0)]) - func = relay.Function(relay.ir_pass.free_vars(result), result) + func = relay.Function(relay.analysis.free_vars(result), result) with relay.build_config(opt_level=0): graph, lib, params = relay.build(relay.Module.from_expr(func), 'llvm') diff --git a/tests/python/relay/test_backend_graph_runtime.py b/tests/python/relay/test_backend_graph_runtime.py index 18e01e39ea276..742e3b4daa9f1 100644 --- a/tests/python/relay/test_backend_graph_runtime.py +++ b/tests/python/relay/test_backend_graph_runtime.py @@ -19,7 +19,6 @@ import tvm from tvm import relay from tvm.contrib import graph_runtime -from tvm.relay.ir_pass import infer_type from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.op import add from tvm.relay.module import Module @@ -124,9 +123,9 @@ def test_plan_memory(): z = relay.exp(z) z = relay.exp(z) func = relay.Function([x, y], z) - func = relay.ir_pass.infer_type(func) - func = relay.ir_pass.fuse_ops(func, opt_level=0) - func = relay.ir_pass.infer_type(func) + mod = relay.Module.from_expr(func) + mod = relay.transform.FuseOps(0)(mod) + func = mod[mod.entry_func] smap = relay.backend._backend.GraphPlanMemory(func) storage_ids = set() device_types = set() diff --git 
a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py index 11ce11e483226..3c79fb7605210 100644 --- a/tests/python/relay/test_backend_interpreter.py +++ b/tests/python/relay/test_backend_interpreter.py @@ -227,7 +227,7 @@ def test_tuple_passing(): gv = relay.GlobalVar('fn') mod[gv] = fn mod.entry_func = gv - mod[gv] = relay.ir_pass.infer_type(mod[gv], mod=mod) + mod = relay.transform.InferType()(mod) ctx = tvm.cpu() target = tvm.target.create('llvm') diff --git a/tests/python/relay/test_error_reporting.py b/tests/python/relay/test_error_reporting.py index c608ebba9b6d5..aad4856fa9431 100644 --- a/tests/python/relay/test_error_reporting.py +++ b/tests/python/relay/test_error_reporting.py @@ -19,7 +19,10 @@ def check_type_err(expr, msg): try: - expr = relay.ir_pass.infer_type(expr) + mod = relay.Module.from_expr(expr) + mod = relay.transform.InferType()(mod) + entry = mod[mod.entry_func] + expr = entry if isinstance(expr, relay.Function) else entry.body assert False except tvm.TVMError as err: assert msg in str(err) diff --git a/tests/python/relay/test_feature.py b/tests/python/relay/test_feature.py index 637e184704f2d..9b5010286d4f6 100644 --- a/tests/python/relay/test_feature.py +++ b/tests/python/relay/test_feature.py @@ -17,7 +17,8 @@ import tvm from tvm import relay -from tvm.relay.ir_pass import detect_feature, gradient +from tvm.relay.analysis import detect_feature +from tvm.relay.transform import gradient from tvm.relay.feature import Feature from tvm.relay.prelude import Prelude @@ -46,7 +47,9 @@ def test_ad(): t = relay.TensorType(shape, dtype) x = relay.var("x", t) func = relay.Function([x], x + x) - back_func = relay.ir_pass.infer_type(gradient(func)) + mod = relay.Module.from_expr(gradient(func)) + mod = relay.transform.InferType()(mod) + back_func = mod[mod.entry_func] feats = detect_feature(back_func) assert feats == set([ Feature.fVar, diff --git a/tests/python/relay/test_ir_bind.py b/tests/python/relay/test_ir_bind.py index 754efa557db6e..df280e2fa2482 100644 --- a/tests/python/relay/test_ir_bind.py +++ b/tests/python/relay/test_ir_bind.py @@ -28,11 +28,11 @@ def test_bind_params(): fexpected =relay.Function( [y], relay.add(relay.const(1, "float32"), y)) - assert relay.ir_pass.alpha_equal(fbinded, fexpected) + assert relay.analysis.alpha_equal(fbinded, fexpected) zbinded = relay.bind(z, {y: x}) zexpected = relay.add(x, x) - assert relay.ir_pass.alpha_equal(zbinded, zexpected) + assert relay.analysis.alpha_equal(zbinded, zexpected) if __name__ == "__main__": diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index cec277371252b..b42a1e6d52c6d 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -19,7 +19,7 @@ from tvm import relay from tvm.expr import * from tvm.relay import op -from tvm.relay.ir_pass import graph_equal +from tvm.relay.analysis import graph_equal def check_json_roundtrip(node): diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 79b010ba0cb06..5f1f65ffb47c8 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -16,7 +16,7 @@ # under the License. 
import tvm from tvm import relay -from tvm.relay.ir_pass import alpha_equal +from tvm.relay.analysis import alpha_equal from nose.tools import nottest, raises from numpy import isclose from typing import Union diff --git a/tests/python/relay/test_ir_well_formed.py b/tests/python/relay/test_ir_well_formed.py index 3cf73ae2cc667..bee0a021ac5b6 100644 --- a/tests/python/relay/test_ir_well_formed.py +++ b/tests/python/relay/test_ir_well_formed.py @@ -16,7 +16,7 @@ # under the License. import tvm from tvm import relay -from tvm.relay.ir_pass import well_formed +from tvm.relay.analysis import well_formed from tvm.relay.prelude import Prelude def test_let(): diff --git a/tests/python/relay/test_op_grad_level1.py b/tests/python/relay/test_op_grad_level1.py index 072271218bdf4..7da623a45ce6e 100644 --- a/tests/python/relay/test_op_grad_level1.py +++ b/tests/python/relay/test_op_grad_level1.py @@ -14,16 +14,24 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm import numpy as np +import tvm from tvm import relay -from tvm.relay.ir_pass import gradient, infer_type +from tvm.relay.transform import gradient from tvm.relay.testing import ctx_list + +def run_infer_type(expr): + mod = relay.Module.from_expr(expr) + mod = relay.transform.InferType()(mod) + return mod[mod.entry_func] + + def sigmoid(x): one = np.ones_like(x) return one / (one + np.exp(-x)) + def relu(x): x_copy = np.copy(x) np.maximum(x_copy, 0, x_copy) @@ -41,7 +49,7 @@ def check_single_op(opfunc, ref): data = np.random.rand(*shape).astype(dtype) ref_grad = ref(data) fwd_func = relay.Function([x], y) - bwd_func = infer_type(gradient(fwd_func)) + bwd_func = run_infer_type(gradient(fwd_func)) for target, ctx in ctx_list(): intrp = relay.create_executor(ctx=ctx, target=target) @@ -73,7 +81,7 @@ def check_binary_op(opfunc, ref): y_data = np.random.rand(*s).astype(t.dtype) ref_grad0, ref_grad1 = ref(x_data, y_data) fwd_func = relay.Function([x, y], z) - bwd_func = infer_type(gradient(fwd_func)) + bwd_func = run_infer_type(gradient(fwd_func)) for target, ctx in ctx_list(): intrp = relay.create_executor(ctx=ctx, target=target) diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 202464493d4b1..8baec8c79e9ab 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -14,13 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-import math -import tvm import numpy as np +import tvm from tvm import relay +from tvm.relay import transform from tvm.relay.testing import ctx_list import topi.testing +def run_infer_type(expr): + mod = relay.Module.from_expr(expr) + mod = transform.InferType()(mod) + entry = mod[mod.entry_func] + return entry if isinstance(expr, relay.Function) else entry.body + def sigmoid(x): one = np.ones_like(x) return one / (one + np.exp(-x)) @@ -44,7 +50,8 @@ def check_single_op(opfunc, ref): # test printer assert ("{}(%x)".format(y.op.name)) in y.astext() # test type inference - assert relay.ir_pass.infer_type(y).checked_type == tp + yy = run_infer_type(y) + assert yy.checked_type == tp if ref is not None: data = np.random.rand(*shape).astype(dtype) @@ -84,7 +91,8 @@ def check_binary_op(opfunc, ref): z = opfunc(x, y) # test printer assert ("{}(%x, %y)".format(z.op.name)) in z.astext() - assert relay.ir_pass.infer_type(z).checked_type == t1 + zz = run_infer_type(z) + assert zz.checked_type == t1 if ref is not None: t1 = relay.TensorType((5, 10, 5)) @@ -134,7 +142,7 @@ def test_bias_add(): x = relay.var("x", shape=xshape) bias = relay.var("bias") z = relay.nn.bias_add(x, bias) - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) assert "axis=" not in zz.astext() assert zz.args[1].checked_type == relay.TensorType(bshape) @@ -153,8 +161,8 @@ def test_expand_dims_infer_type(): x = relay.var("x", shape=(n, t, d)) y = relay.expand_dims(x, axis=2) assert "axis=2" in y.astext() - checked = relay.ir_pass.infer_type(y) - assert checked.checked_type == relay.TensorType((n, t, 1, 100)) + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType((n, t, 1, 100)) def test_softmax(): @@ -162,7 +170,7 @@ def test_softmax(): x = relay.var("x", shape=shape) y = relay.nn.softmax(x, axis=1) assert "nn.softmax" in y.astext() - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType(shape) func = relay.Function([x], y) x_data = np.random.uniform(size=shape).astype("float32") @@ -178,7 +186,7 @@ def test_log_softmax(): x = relay.var("x", shape=shape) y = relay.nn.log_softmax(x, axis=1) assert "nn.log_softmax" in y.astext() - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType(shape) func = relay.Function([x], y) x_data = np.random.uniform(size=shape).astype("float32") @@ -195,16 +203,16 @@ def test_concatenate(): y = relay.var("y", shape=(n, t, d)) z = relay.concatenate((x, y), axis=-1) assert "axis=" in z.astext() - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((n, t, 200)) x = relay.exp(x) z = relay.concatenate((x, y), axis=2) - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((n, t, 200)) z = relay.concatenate((x, y), axis=1) - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((n, t + t, 100)) x = relay.var("x", shape=(10, 5)) @@ -233,7 +241,7 @@ def test_dropout(): x = relay.var("x", input_ty) y = relay.nn.dropout(x, rate=0.75) assert "rate=" in y.astext() - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == input_ty @@ -246,7 +254,7 @@ def test_batch_norm(): moving_var = relay.var("moving_var", relay.TensorType((2,))) y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, center=False, scale=False) - yy = relay.ir_pass.infer_type(y.astuple()) + yy = run_infer_type(y.astuple()) assert "center=" in 
yy.astext() assert yy.checked_type == relay.ty.TupleType(tvm.convert([ relay.TensorType((3, 2, 1), "float32"), @@ -261,7 +269,7 @@ def test_batch_norm(): y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, axis=0, center=False, scale=False) - yy = relay.ir_pass.infer_type(y.astuple()) + yy = run_infer_type(y.astuple()) assert yy.checked_type == relay.ty.TupleType(tvm.convert([ relay.ty.TensorType((3, 2, 1), "float32"), relay.ty.TensorType((3,), "float32"), @@ -276,7 +284,7 @@ def test_batch_norm(): moving_var = relay.var("moving_var", relay.TensorType((3,))) y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, axis=-1, center=False, scale=False) - yy = relay.ir_pass.infer_type(y.astuple()) + yy = run_infer_type(y.astuple()) assert yy.checked_type == relay.ty.TupleType(tvm.convert([ relay.ty.TensorType((1, 2, 3), "float32"), relay.ty.TensorType((3,), "float32"), @@ -290,7 +298,7 @@ def test_dense(): w = relay.var("w", relay.TensorType((2, w), "float32")) y = relay.nn.dense(x, w, units=2) "units=2" in y.astext() - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n, c, h, 2), "float32") n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2 @@ -298,14 +306,14 @@ def test_dense(): wh, ww = tvm.var("wh"), tvm.var("ww") w = relay.var("w", relay.TensorType((ww, wh), "float32")) y = relay.nn.dense(x, w) - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n, c, h, ww), "float32") n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2 x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) w = relay.var("w", relay.IncompleteType()) y = relay.nn.dense(x, w, units=2) - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n, c, h, 2), "float32") x = relay.var("x", shape=(10, 5)) diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index f904fb01fdb90..bcf6b7f80abd5 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -20,10 +20,17 @@ import tvm import topi.testing from tvm import relay +from tvm.relay import transform from tvm.relay.testing import ctx_list import topi import topi.testing +def run_infer_type(expr): + mod = relay.Module.from_expr(expr) + mod = transform.InferType()(mod) + entry = mod[mod.entry_func] + return entry if isinstance(expr, relay.Function) else entry.body + def test_collapse_sum_like(): shape = (3, 4, 5, 6) shape_like = (4, 5, 6) @@ -31,7 +38,7 @@ def test_collapse_sum_like(): x = relay.Var("x", relay.ty.TensorType(shape , dtype)) y = relay.Var("y", relay.ty.TensorType(shape_like, dtype)) z = relay.collapse_sum_like(x, y) - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) assert zz.checked_type == relay.ty.TensorType(shape_like, dtype) func = relay.Function([x, y], z) @@ -50,7 +57,7 @@ def test_broadcast_to(): dtype = "float32" x = relay.Var("x", relay.ty.TensorType(shape , dtype)) z = relay.broadcast_to(x, shape=shape_like) - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) assert zz.checked_type == relay.ty.TensorType(shape_like, dtype) func = relay.Function([x], z) @@ -69,7 +76,7 @@ def test_broadcast_to_like(): x = relay.Var("x", relay.ty.TensorType(shape , dtype)) y = relay.Var("y", relay.ty.TensorType(shape_like, dtype)) z = relay.broadcast_to_like(x, y) - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) assert zz.checked_type == relay.ty.TensorType(shape_like, dtype) 
func = relay.Function([x, y], z) @@ -106,7 +113,7 @@ def verify_slice_like(data, slice_like, axes, output, dtype="float32"): x = relay.var("data", relay.TensorType(data, dtype)) y = relay.var("slice_like", relay.TensorType(slice_like, dtype)) z = relay.slice_like(x, y, axes) - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) if axes: assert "axes" in z.astext() assert zz.checked_type == relay.ty.TensorType(output, dtype) @@ -144,7 +151,7 @@ def test_reverse_reshape(): def verify_reverse_reshape(shape, newshape, oshape): x = relay.var("x", relay.TensorType(shape, "float32")) z = relay.reverse_reshape(x, newshape=newshape) - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) assert "newshape=" in z.astext() assert zz.checked_type == relay.ty.TensorType(oshape, "float32") @@ -166,7 +173,7 @@ def verify_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"): x = relay.var("x", relay.TensorType(x_shape, dtype)) y = relay.var("y", relay.TensorType(y_shape, dtype)) z = relay.nn.batch_matmul(x, y) - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) assert zz.checked_type == relay.ty.TensorType(out_shape, dtype) func = relay.Function([x, y], z) @@ -185,7 +192,7 @@ def test_batch_matmul(): x = relay.var("x", relay.TensorType((b, m, k), "float32")) y = relay.var("y", relay.TensorType((b, n, k), "float32")) z = relay.nn.batch_matmul(x, y) - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((b, m, n), "float32") verify_batch_matmul((1, 16, 32), (1, 16, 32), (1, 16, 16)) @@ -197,7 +204,7 @@ def test_shape_of(): shape = (10, 5, 12) x = relay.var("x", shape=shape) func = relay.Function([x], relay.op.shape_of(x)) - func = relay.ir_pass.infer_type(func) + func = run_infer_type(func) x_data = np.random.rand(*shape).astype('float32') for target, ctx in ctx_list(): # Because using graph executor, this op will be optimized after @@ -256,7 +263,8 @@ def _verify(data_shape, mask_value, axis, dtype, itype): data = relay.var("data", relay.TensorType(data_shape, dtype)) valid_length = relay.var("valid_length", relay.TensorType((nbatch,), itype)) out = relay.sequence_mask(data, valid_length, mask_value, axis) - assert relay.ir_pass.infer_type(out).checked_type == relay.ty.TensorType(data_shape, dtype) + checked = run_infer_type(out) + assert checked.checked_type == relay.ty.TensorType(data_shape, dtype) func = relay.Function([data, valid_length], out) data_np = np.random.uniform(size=data_shape).astype(dtype) valid_length_np = np.random.randint(0, max_length, size=nbatch).astype(itype) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index c8f5b1d27a2a8..722e8d178fab7 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -16,12 +16,19 @@ # under the License. """ Support level2 operator test cases. 
""" +import numpy as np import tvm from tvm import relay +from tvm.relay import transform from tvm.relay.testing import ctx_list -import numpy as np import topi.testing +def run_infer_type(expr): + mod = relay.Module.from_expr(expr) + mod = transform.InferType()(mod) + entry = mod[mod.entry_func] + return entry if isinstance(expr, relay.Function) else entry.body + def test_conv2d_infer_type(): # symbolic in batch dimension n, c, h, w = tvm.var("n"), 10, 224, 224 @@ -31,7 +38,7 @@ def test_conv2d_infer_type(): kernel_size=(3, 3), padding=(1, 1), channels=2) - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType( (n, 2, 224, 224), "float32") assert yy.args[1].checked_type == relay.TensorType( @@ -44,7 +51,7 @@ def test_conv2d_infer_type(): w = relay.var("w", relay.TensorType((2, 10, 3, 3), "int8")) y = relay.nn.conv2d(x, w, out_dtype="int32") assert "out_dtype=\"int32\"" in y.astext() - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType( (n, 2, 222, 222), "int32") @@ -59,7 +66,7 @@ def test_conv2d_infer_type(): data_layout="NCHW4n4c", kernel_layout="OIHW4o4i", out_dtype="int32") - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType( (1, 4, 224, 224, 4, 4), "int32") assert yy.args[1].checked_type == relay.TensorType( @@ -75,7 +82,7 @@ def test_conv2d_infer_type(): channels=16, data_layout="NHWC", out_dtype="int32") - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType( (n, h, w, 16), "int32") @@ -169,7 +176,7 @@ def test_conv2d_transpose_infer_type(): padding=(1, 1), channels=15) assert "channels=15" in y.astext() - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType( (n, 15, 10, 12), "float32") assert yy.args[1].checked_type == relay.TensorType( @@ -183,7 +190,7 @@ def test_conv2d_transpose_infer_type(): output_padding=(1, 1), channels=11, data_layout="NHWC") - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType( (n, 15, 15, 11), "float32") @@ -219,12 +226,12 @@ def test_upsampling_infer_type(): x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR") "method=\"BINLINEAR\"" in y.astext() - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n, c, h*2, w*2), "float32") n, c = tvm.var("n"), tvm.var("c") x = relay.var("x", relay.TensorType((n, c, 100, 200), "float32")) y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR") - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n, c, 200, 400), "float32") @@ -233,7 +240,7 @@ def _test_pool2d(opfunc, reffunc): x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) y = opfunc(x, pool_size=(1, 1)) assert "pool_size=" in y.astext() - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n, 10, 224, 224), "float32") # test execution dtype = "float32" @@ -253,13 +260,13 @@ def _test_global_pool2d(opfunc, reffunc): n, c, h, w = tvm.var("n"), tvm.var("c"), 224, 224 x = relay.var("x", relay.TensorType((n, h, w, c), "float32")) y = opfunc(x, layout="NHWC") - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n, 1, 1, c), "float32") n, c, h, w = tvm.var("n"), tvm.var("c"), 
tvm.var("h"), tvm.var("w") x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) y = opfunc(x) - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n, c, 1, 1), "float32") # test execution dtype = "float32" @@ -320,17 +327,17 @@ def test_flatten_infer_type(): d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4") x = relay.var("x", relay.TensorType((d1, d2, d3, d4), "float32")) y = relay.nn.batch_flatten(x) - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((d1, ((d2*d3)*d4)), "float32") x = relay.var("x", relay.TensorType((3, 2, 4, 3), "float32")) y = relay.nn.batch_flatten(x) - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((3, 24), "float32") x = relay.var("x", relay.TensorType((d1, 2, d3, 3), "float32")) y = relay.nn.batch_flatten(x) - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((d1, ((2*d3)*3)), "float32") shape = (1, 5, 10, 10) @@ -338,7 +345,7 @@ def test_flatten_infer_type(): dtype = "float32" x = relay.var("x", relay.TensorType(shape, dtype)) z = relay.nn.batch_flatten(x) - yy = relay.ir_pass.infer_type(z) + yy = run_infer_type(z) assert yy.checked_type == relay.TensorType(o_shape, dtype) func = relay.Function([x], z) x_data = np.random.uniform(low=-1, high=1, size=shape).astype(dtype) @@ -358,14 +365,14 @@ def test_pad_infer_type(): t = relay.var("t", relay.TensorType((n, c, h, w), "float32")) y = relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4))) "pad_width=" in y.astext() - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((3, 6, 9, 12), "float32") # some symbolic values n, c, h, w = tvm.var("n"), 2, 3, tvm.var("w") t = relay.var("t", relay.TensorType((n, c, h, w), "float32")) y = relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4))) - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n + 2, 6, 9, w + 8), "float32") def test_pad_run(): @@ -389,7 +396,7 @@ def test_lrn(): x = relay.var("x", shape=(n, c , h, w)) y = relay.nn.lrn(x, size=10, axis=2, bias=0.5, alpha=.00001, beta=0.75) "alpha=" in y.astext() - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n, c , h, w)) shape = (1, 5, 10, 10) @@ -401,7 +408,7 @@ def test_lrn(): alpha=.00001 beta=0.75 z = relay.nn.lrn(x, size=size, axis=axis, bias=bias, alpha=alpha, beta=beta) - yy = relay.ir_pass.infer_type(z) + yy = run_infer_type(z) assert yy.checked_type == relay.TensorType(shape, dtype) func = relay.Function([x], z) x_data = np.random.uniform(low=-1, high=1, size=shape).astype(dtype) @@ -420,7 +427,7 @@ def test_l2_normalize(): x = relay.var("x", shape=(n, c , h, w)) y = relay.nn.l2_normalize(x, eps=0.001, axis=[1]) "axis=" in y.astext() - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n, c , h, w)) shape = (1, 5, 10, 10) @@ -429,7 +436,7 @@ def test_l2_normalize(): eps=0.001 axis=1 z = relay.nn.l2_normalize(x, eps=0.001, axis=[axis]) - yy = relay.ir_pass.infer_type(z) + yy = run_infer_type(z) assert yy.checked_type == relay.TensorType(shape, dtype) func = relay.Function([x], z) x_data = np.random.uniform(low=-1, high=1, size=shape).astype(dtype) @@ -477,7 +484,7 @@ def get_shape(): ishape, oshape = get_shape() x = relay.var("x", relay.TensorType((n,) + ishape, dtype)) y = 
relay.nn.upsampling(x, scale=scale, layout=layout, method=method) - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n,) + oshape, dtype) dshape = (1,) + ishape x = relay.var("x", shape=dshape) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index fcd4caff2695f..575996fbe61ee 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -16,17 +16,23 @@ # under the License. """ Support level3 operator test cases. """ -import tvm import numpy as np +from nose.tools import raises +import tvm from tvm import relay -from tvm.relay import create_executor +from tvm.relay import create_executor, transform from tvm.relay.testing import ctx_list -from nose.tools import raises + +def run_infer_type(expr): + mod = relay.Module.from_expr(expr) + mod = transform.InferType()(mod) + entry = mod[mod.entry_func] + return entry if isinstance(expr, relay.Function) else entry.body def test_zeros_ones(): for op, ref in [(relay.zeros, np.zeros), (relay.ones, np.ones)]: y = op(shape=(124, 50), dtype="float64") - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((124, 50), "float64") intrp = create_executor() intrp_res = intrp.evaluate(y).asnumpy() @@ -46,7 +52,7 @@ def test_unary_identity(): shape = (8, 9, 4) x = relay.var("x", relay.TensorType(shape, "float32")) y = op(x) - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType(shape, "float32") if ref is not None: @@ -59,20 +65,20 @@ def test_unary_identity(): def test_cast(): x = relay.var("x", relay.TensorType((8, 9, 4), "float32")) y = x.astype("int32") - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert "dtype=" in yy.astext() assert yy.checked_type == relay.TensorType((8, 9, 4), "int32") x = relay.var("x", relay.TensorType((8, 9, 4), "float32")) y = relay.cast(x, "int32") - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert "dtype=" in yy.astext() assert yy.checked_type == relay.TensorType((8, 9, 4), "int32") def test_clip(): a = relay.var("a", relay.TensorType((10, 4), "float32")) y = relay.clip(a, 1., 4.) 
- yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((10, 4), "float32") data = np.random.rand(10, 4).astype('float32') @@ -105,13 +111,13 @@ def test_transpose_infer_type(): x = relay.var("x", relay.TensorType((n, t, d), "float32")) y = relay.transpose(x, axes=(1, 0, 2)) assert "axes=" in y.astext() - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType( (t, n, 100), "float32") y = relay.transpose(x) assert "axes=" in y.astext() - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType( (100, t, n), "float32") @@ -138,7 +144,7 @@ def test_squeeze_infer_type(): x = relay.var("x", relay.TensorType((n, t, d), "float32")) y = relay.squeeze(x, axis=(2,)) assert "axis=" in y.astext() - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType( (1, 4), "float32") @@ -146,7 +152,7 @@ def test_squeeze_infer_type(): x = relay.var("x", relay.TensorType((n, t, d), "float32")) y = relay.squeeze(x) assert "axis=" not in y.astext() - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType( (4,), "float32") @@ -156,7 +162,7 @@ def test_squeeze_bad_axes_infer_type(): n, t, d = 1, 4, 1 x = relay.var("x", relay.TensorType((n, t, d), "float32")) y = relay.squeeze(x, axis=(1,)) - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) def test_reshape_infer_type(): @@ -164,7 +170,7 @@ def test_reshape_infer_type(): x = relay.var("x", relay.TensorType((n, t, d1, d2), "float32")) y = relay.reshape(x, newshape=(n, t, 2000)) assert "newshape=" in y.astext() - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType( (n, t, 2000), "float32") @@ -172,7 +178,7 @@ def test_reshape(): def verify_reshape(shape, newshape, oshape): x = relay.var("x", relay.TensorType(shape, "float32")) z = relay.reshape(x, newshape=newshape) - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) assert "newshape=" in z.astext() assert zz.checked_type == relay.ty.TensorType(oshape, "float32") @@ -205,7 +211,7 @@ def test_reshape_like_infer_type(): x = relay.var("x", relay.TensorType((1, 2, 3), "float32")) y = relay.var("y", relay.TensorType((1,6), "float32")) z = relay.reshape_like(x, y) - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((1, 6), "float32") # symbolic shape @@ -213,7 +219,7 @@ def test_reshape_like_infer_type(): x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) y = relay.var("y", relay.TensorType((1, 8, 8), "float32")) z = relay.reshape_like(x, y) - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((1, 8, 8), "float32") @@ -226,7 +232,7 @@ def verify_reshape_like(shape, oshape): x = relay.var("x", relay.TensorType(shape, "float32")) y = relay.var("x", relay.TensorType(oshape, "float32")) z = relay.reshape_like(x, y) - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) assert zz.checked_type == relay.ty.TensorType(ref_res.shape, "float32") func = relay.Function([x, y], z) @@ -245,8 +251,7 @@ def verify_take(dshape, indices_shape, oshape, axis=None): x = relay.var("x", relay.TensorType(dshape, "float32")) indices = relay.var("indices", relay.TensorType(indices_shape, "int32")) y = relay.take(x, indices, axis=axis) - y.astext() - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == 
relay.TensorType(oshape, "float32") d1, d2, d3 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3") @@ -301,8 +306,7 @@ def test_split_infer_type(): def verify_split(dshape, indices_or_sections, ret_type, axis=None): x = relay.var("x", relay.ty.TensorType(dshape, "float32")) y = relay.split(x, indices_or_sections, axis=axis) - y.astext() - yy = relay.ir_pass.infer_type(y.astuple()) + yy = run_infer_type(y.astuple()) assert yy.checked_type == ret_type d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4") @@ -347,14 +351,14 @@ def test_full_infer_type(): # default settings: match input dtype x = relay.var("x", relay.TensorType((), "int8")) y = relay.full(x, ()) - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((), "int8") # change the shape and dtype x = relay.var("x", relay.TensorType((), "float32")) y = relay.full(x, (1, 2), "int8") "shape=" in y.astext() - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((1, 2), "int8") @@ -378,7 +382,7 @@ def test_full_like_infer_type(): base = relay.var("base", relay.TensorType((1, 2, 3), "float32")) fill = relay.var("fill", relay.TensorType((), "float32")) y = relay.full_like(base, fill) - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((1, 2, 3), "float32") # symbolic shape @@ -386,7 +390,7 @@ def test_full_like_infer_type(): base = relay.var("base", relay.TensorType((n, c, h, w), "float32")) fill = relay.var("fill", relay.TensorType((), "float32")) y = relay.full_like(base, fill) - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n, c, h, w), "float32") @@ -414,7 +418,7 @@ def test_infer_type_leaky_relu(): x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) y = relay.nn.leaky_relu(x, alpha=0.1) "alpha=0.1" in y.astext() - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n, c, h, w), "float32") shape = (1, 5, 10, 10) @@ -422,8 +426,8 @@ def test_infer_type_leaky_relu(): x = relay.var("x", relay.TensorType(shape, dtype)) z = relay.nn.leaky_relu(x, alpha=0.1) assert "alpha=0.1" in z.astext() - yy = relay.ir_pass.infer_type(z) - assert yy.checked_type == relay.TensorType(shape, dtype) + zz = run_infer_type(z) + assert zz.checked_type == relay.TensorType(shape, dtype) func = relay.Function([x], z) x_data = np.random.uniform(low=-1, high=1, size=shape).astype(dtype) ref_res = np.where(x_data > 0, x_data, x_data * 0.1) @@ -443,7 +447,7 @@ def verify_infer_type_prelu(data, alpha, axis, output, dtype="float32"): else: y = relay.var("alpha", relay.IncompleteType()) z = relay.nn.prelu(x, y, axis=axis) - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) if axis != 1: assert "axis" in z.astext() assert zz.checked_type == relay.ty.TensorType(output, dtype) @@ -577,7 +581,7 @@ def test_reverse(): def verify_reverse(dshape, axis): x = relay.var("x", relay.TensorType(dshape, "float32")) z = relay.reverse(x, axis=axis) - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) func = relay.Function([x], z) x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32") diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index da0fe01063f4a..9bab5d87389a3 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -17,9 +17,16 @@ import tvm import numpy as np from tvm import relay +from 
tvm.relay import transform from tvm.relay.testing import ctx_list import topi.testing +def run_infer_type(expr): + mod = relay.Module.from_expr(expr) + mod = transform.InferType()(mod) + entry = mod[mod.entry_func] + return entry if isinstance(expr, relay.Function) else entry.body + def test_binary_op(): def check_binary_op(opfunc, ref): n = tvm.var("n") @@ -30,7 +37,8 @@ def check_binary_op(opfunc, ref): z = opfunc(x, y) # test printer assert ("{}(%x, %y)".format(z.op.name)) in z.astext() - assert relay.ir_pass.infer_type(z).checked_type == t1 + zz = run_infer_type(z) + assert zz.checked_type == t1 if ref is not None: t1 = relay.TensorType((5, 10, 5)) @@ -62,8 +70,7 @@ def test_cmp_type(): x = relay.var("x", relay.TensorType((10, 4), "float32")) y = relay.var("y", relay.TensorType((5, 10, 1), "float32")) z = op(x, y) - z.astext() - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((5, 10, 4), "bool") if ref is not None: @@ -94,7 +101,7 @@ def test_binary_int_broadcast(): x = relay.var("x", relay.TensorType((10, 4), "int32")) y = relay.var("y", relay.TensorType((5, 10, 1), "int32")) z = op(x, y) - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((5, 10, 4), "int32") if ref is not None: @@ -120,7 +127,7 @@ def test_where(): x = relay.var("x", relay.TensorType(shape, dtype)) y = relay.var("y", relay.TensorType(shape, dtype)) z = relay.where(cond, x, y) - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) assert zz.checked_type == relay.TensorType(shape, dtype) func = relay.Function([cond, x, y], z) @@ -142,7 +149,7 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32") x = relay.var("x", relay.TensorType(data, dtype)) z = test_func(x, axis, keepdims, exclude) - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) if axis: assert "axis=" in z.astext() if keepdims: @@ -224,7 +231,7 @@ def verify(dshape, begin, end, strides, output, test_ref=True): x = relay.var("x", relay.TensorType(dshape, "float32")) z = relay.strided_slice(x, begin=begin, end=end, strides=strides) func = relay.Function([x], z) - func = relay.ir_pass.infer_type(func) + func = run_infer_type(func) text = func.astext() assert "begin=" in text assert "end=" in text diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 3d9ec6dde4adc..cd008e3d19a3a 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -20,21 +20,28 @@ import numpy as np import tvm from tvm import relay +from tvm.relay import transform from tvm.relay.testing import ctx_list import topi.testing +def run_infer_type(expr): + mod = relay.Module.from_expr(expr) + mod = transform.InferType()(mod) + entry = mod[mod.entry_func] + return entry if isinstance(expr, relay.Function) else entry.body + def test_resize_infer_type(): n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) th, tw = tvm.var("th"), tvm.var("tw") z = relay.image.resize(x, (th, tw)) - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((n, c, th, tw), "int8") x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) z= relay.image.resize(x, (100, 200), "NCHW", "BILINEAR", False) assert "size=" in z.astext() - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8") def test_resize(): @@ -52,7 
+59,7 @@ def verify_resize(dshape, scale, method, layout): x = relay.var("x", relay.TensorType(dshape, "float32")) z = relay.image.resize(x, size, layout, method, False) assert "size=" in z.astext() - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) assert zz.checked_type == relay.TensorType(ref_res.shape, "float32") func = relay.Function([x], z) @@ -109,7 +116,7 @@ def verify_multibox_prior(x, dshape, ref_res, sizes=(1.0,), check_type_only=False): z = relay.vision.multibox_prior(x, sizes, ratios, steps, offsets, clip) - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) if check_size: assert "sizes=" in z.astext() assert zz.checked_type == relay.TensorType( @@ -121,7 +128,7 @@ def verify_multibox_prior(x, dshape, ref_res, sizes=(1.0,), data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32") func = relay.Function([x], z) - func = relay.ir_pass.infer_type(func) + func = run_infer_type(func) for target, ctx in ctx_list(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(data) @@ -176,7 +183,7 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): z = relay.vision.get_valid_counts(x, score_threshold, id_index, score_index) assert "score_threshold" in z.astext() func = relay.Function([x], z.astuple()) - func = relay.ir_pass.infer_type(func) + func = run_infer_type(func) for target, ctx in ctx_list(): if target == 'cuda': return @@ -205,8 +212,8 @@ def verify_nms(x0_data, x1_data, dshape, ref_res, ref_indices_res, top_k = top_k) assert "iou_threshold" in z.astext() assert "iou_threshold" in z_indices.astext() - zz = relay.ir_pass.infer_type(z) - zz_indices = relay.ir_pass.infer_type(z_indices) + zz = run_infer_type(z) + zz_indices = run_infer_type(z_indices) assert zz.checked_type == relay.ty.TensorType(dshape, "float32") assert zz_indices.checked_type == relay.ty.TensorType((dshape[0], dshape[1]), "int32") @@ -214,9 +221,9 @@ def verify_nms(x0_data, x1_data, dshape, ref_res, ref_indices_res, return func = relay.Function([x0, x1], z) - func = relay.ir_pass.infer_type(func) + func = run_infer_type(func) func_indices = relay.Function([x0, x1], z_indices) - func_indices = relay.ir_pass.infer_type(func_indices) + func_indices = run_infer_type(func_indices) for target, ctx in ctx_list(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(x0_data, x1_data) @@ -288,7 +295,7 @@ def test_default_value(): mtl = relay.vision.multibox_transform_loc( cls_prob=cls_prob, loc_pred=loc_pred, anchor=anchors) - ret = relay.ir_pass.infer_type(mtl.astuple()) + ret = run_infer_type(mtl.astuple()) ref_type = relay.ty.TupleType( tvm.convert([ relay.ty.TensorType((1, num_anchors, 6), "float32"), @@ -299,7 +306,7 @@ def test_default_value(): nms = relay.vision.non_max_suppression(mtl[0], mtl[1], return_indices=False) func = relay.Function([cls_prob, loc_pred, anchors], nms) - func = relay.ir_pass.infer_type(func) + func = run_infer_type(func) for target, ctx in ctx_list(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(np_cls_prob, np_loc_preds, @@ -330,7 +337,7 @@ def test_threshold(): anchor=anchors, threshold=threshold, variances=variances) - ret = relay.ir_pass.infer_type(ret.astuple()) + ret = run_infer_type(ret.astuple()) ref_type = relay.ty.TupleType( tvm.convert([ relay.ty.TensorType((n, num_anchors, 6), "float32"), @@ -349,15 +356,14 @@ def verify_roi_align(data_shape, rois_shape, pooled_size, spatial_scale, sample_ 
z = relay.vision.roi_align(data, rois, pooled_size=(pooled_size, pooled_size), spatial_scale=spatial_scale, sample_ratio=sample_ratio, layout="NCHW") - zz = relay.ir_pass.infer_type(z) - + zz = run_infer_type(z) batch, channel, in_size, _ = data_shape num_roi = rois_shape[0] assert zz.checked_type == relay.ty.TensorType( (num_roi, channel, pooled_size, pooled_size), "float32") func = relay.Function([data, rois], z) - func = relay.ir_pass.infer_type(func) + func = run_infer_type(func) np_data = np.random.uniform(size=data_shape).astype("float32") np_rois = np.random.uniform(size=rois_shape).astype('float32') * in_size np_rois[:, 0] = np.random.randint(low = 0, high = batch, size = num_roi) @@ -382,15 +388,14 @@ def verify_roi_pool(data_shape, rois_shape, pooled_size, spatial_scale): rois = relay.var("rois", relay.ty.TensorType(rois_shape, "float32")) z = relay.vision.roi_pool(data, rois, pooled_size=(pooled_size, pooled_size), spatial_scale=spatial_scale, layout="NCHW") - zz = relay.ir_pass.infer_type(z) - + zz = run_infer_type(z) batch, channel, in_size, _ = data_shape num_roi = rois_shape[0] assert zz.checked_type == relay.ty.TensorType( (num_roi, channel, pooled_size, pooled_size), "float32") func = relay.Function([data, rois], z) - func = relay.ir_pass.infer_type(func) + func = run_infer_type(func) np_data = np.random.uniform(size=data_shape).astype("float32") np_rois = np.random.uniform(size=rois_shape).astype('float32') * in_size np_rois[:, 0] = np.random.randint(low = 0, high = batch, size = num_roi).astype('float32') @@ -414,12 +419,11 @@ def verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs): bbox_pred = relay.var("bbox_pred", relay.ty.TensorType(np_bbox_pred.shape, "float32")) im_info = relay.var("im_info", relay.ty.TensorType(np_im_info.shape, "float32")) z = relay.vision.proposal(cls_prob, bbox_pred, im_info, **attrs) - zz = relay.ir_pass.infer_type(z) - + zz = run_infer_type(z) assert zz.checked_type == relay.ty.TensorType(np_out.shape, "float32") func = relay.Function([cls_prob, bbox_pred, im_info], z) - func = relay.ir_pass.infer_type(func) + func = run_infer_type(func) for target in ['cuda']: if not tvm.module.enabled(target): print("Skip test because %s is not enabled." 
% target) @@ -478,7 +482,7 @@ def test_yolo_reorg_infer_shape(): def verify_yolo_reorg(shape, stride, out_shape): x = relay.var("x", relay.TensorType(shape, "float32")) z = relay.vision.yolo_reorg(x, stride=stride) - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) assert "stride=" in z.astext() assert zz.checked_type == relay.ty.TensorType(out_shape, "float32") @@ -493,7 +497,7 @@ def verify_yolo_reorg(shape, stride): x = relay.var("x", relay.TensorType(shape, "float32")) z = relay.vision.yolo_reorg(x, stride=stride) - zz = relay.ir_pass.infer_type(z) + zz = run_infer_type(z) assert "stride=" in z.astext() assert zz.checked_type == relay.ty.TensorType(ref_res.shape, "float32") @@ -527,7 +531,7 @@ def test_infer_type(batch, in_channel, size, out_channel, deformable_groups, gro weight_shape = (out_channel, in_channel // groups, kernel_size[0], kernel_size[1]) out_shape = (batch, out_channel, size, size) offset_shape = (batch, 2 * kernel_size[0] * kernel_size[1] * deformable_groups, out_shape[2], out_shape[3]) - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.TensorType(out_shape) assert yy.args[1].checked_type == relay.TensorType(offset_shape), yy.args[1].checked_type assert yy.args[2].checked_type == relay.TensorType(weight_shape) diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index 0e0036565363e..de764f849c1c7 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -14,17 +14,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm import numpy as np +import tvm from tvm import relay -from tvm.relay import ir_pass +from tvm.relay import analysis def alpha_equal(x, y): """ Wrapper around alpha equality which ensures that the hash function respects equality. 
""" - return ir_pass.alpha_equal(x, y) and ir_pass.structural_hash(x) == ir_pass.structural_hash(y) + return analysis.alpha_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y) def test_tensor_type_alpha_equal(): t1 = relay.TensorType((3, 4), "float32") @@ -604,14 +604,14 @@ def test_hash_unequal(): y2 = relay.var("y2", shape=(10, 10), dtype="float32") func2 = relay.Function([x2, y2], relay.add(x2, y2)) - assert ir_pass.structural_hash(func1) == ir_pass.structural_hash(func2) + assert analysis.structural_hash(func1) == analysis.structural_hash(func2) # func3 is same as func1 but with different var shapes x3 = relay.var("x3", shape=(20, 10), dtype="float32") y3 = relay.var("y3", shape=(20, 10), dtype="float32") func3 = relay.Function([x3, y3], relay.add(x3, y3)) - assert not ir_pass.structural_hash(func1) == ir_pass.structural_hash(func3) + assert not analysis.structural_hash(func1) == analysis.structural_hash(func3) if __name__ == "__main__": test_tensor_type_alpha_equal() diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 7d022ba255708..65fd0b0819ccb 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -19,7 +19,18 @@ from tvm import relay from tvm.relay.op import register_alter_op_layout -from tvm.relay.ir_pass import * +from tvm.relay import transform, analysis + + +def run_opt_pass(expr, passes): + passes = passes if isinstance(passes, list) else [passes] + mod = relay.Module.from_expr(expr) + seq = transform.Sequential(passes) + with transform.PassContext(opt_level=3): + mod = seq(mod) + entry = mod[mod.entry_func] + return entry if isinstance(expr, relay.Function) else entry.body + def test_alter_op(): """Test directly replacing an operator with a new one""" @@ -52,13 +63,10 @@ def expected(): return y a = before() - a = infer_type(a) - a = alter_op_layout(a) + a = run_opt_pass(a, transform.AlterOpLayout()) + b = run_opt_pass(expected(), transform.InferType()) - b = expected() - b = infer_type(b) - - assert alpha_equal(a, b), "Actual = \n" + str(a) + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) def test_alter_return_none(): @@ -77,12 +85,11 @@ def alter_conv2d(attrs, inputs, tinfos): return None a = before() - a = infer_type(a) - a = alter_op_layout(a) + a = run_opt_pass(a, transform.AlterOpLayout()) b = before() - b = infer_type(b) - assert alpha_equal(a, b), "Actual = \n" + str(a) + b = run_opt_pass(b, transform.InferType()) + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) assert(called[0]) @@ -102,7 +109,7 @@ def before(): y = relay.nn.max_pool2d(y, pool_size=(2, 2)) y = relay.cast(y, 'int32') y = relay.nn.batch_flatten(y) - y = relay.Function(free_vars(y), y) + y = relay.Function(analysis.free_vars(y), y) return y @register_alter_op_layout("nn.conv2d", level=102) @@ -135,20 +142,17 @@ def expected(): y = relay.cast(y, 'int32') y = relay.layout_transform(y, "NCHW16c", "NCHW") y = relay.nn.batch_flatten(y) - y = relay.Function(free_vars(y), y) + y = relay.Function(analysis.free_vars(y), y) return y a = before() - a = infer_type(a) - a = canonicalize_ops(a) - a = infer_type(a) - a = alter_op_layout(a) - a = infer_type(a) + a = run_opt_pass(a, [transform.CanonicalizeOps(), + transform.AlterOpLayout()]) b = expected() - b = infer_type(b) + b = run_opt_pass(b, transform.InferType()) - assert alpha_equal(a, b), "Actual = \n" + str(a) + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) def 
test_alter_layout_dual_path(): @@ -172,7 +176,7 @@ def before(): y1 = relay.nn.relu(y1) y2 = relay.nn.batch_flatten(y) ret = relay.Tuple([y1, y2]) - y = relay.Function(free_vars(ret), ret) + y = relay.Function(analysis.free_vars(ret), ret) return y @register_alter_op_layout("nn.conv2d", level=103) @@ -203,18 +207,16 @@ def expected(): y2 = relay.layout_transform(y, "NCHW16c", "NCHW") y2 = relay.nn.batch_flatten(y2) ret = relay.Tuple([y1, y2]) - y = relay.Function(free_vars(ret), ret) + y = relay.Function(analysis.free_vars(ret), ret) return y a = before() - a = infer_type(a) - a = alter_op_layout(a) - a = infer_type(a) + a = run_opt_pass(a, transform.AlterOpLayout()) b = expected() - b = infer_type(b) + b = run_opt_pass(b, transform.InferType()) - assert alpha_equal(a, b), "Actual = \n" + str(a) + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) def test_alter_layout_resnet(): """Test alternating the layout of a residual block @@ -236,7 +238,7 @@ def before(): y2 = relay.nn.relu(y2) y = y + y2 y = relay.nn.global_max_pool2d(y) - return relay.Function(free_vars(y), y) + return relay.Function(analysis.free_vars(y), y) @register_alter_op_layout("nn.conv2d", level=104) def alter_conv2d(attrs, inputs, tinfos): @@ -264,17 +266,15 @@ def expected(): y = y + y2 y = relay.nn.global_max_pool2d(y, layout="NCHW16c") y = relay.layout_transform(y, "NCHW16c", "NCHW") - return relay.Function(free_vars(y), y) + return relay.Function(analysis.free_vars(y), y) a = before() - a = infer_type(a) - a = alter_op_layout(a) - a = infer_type(a) + a = run_opt_pass(a, transform.AlterOpLayout()) b = expected() - b = infer_type(b) + b = run_opt_pass(b, transform.InferType()) - assert alpha_equal(a, b), "Actual = \n" + str(a) + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) def test_alter_layout_broadcast_op(): @@ -287,7 +287,7 @@ def before(): y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) y = relay.nn.bias_add(y, bias) # test broadcasting to lhs y = relay.multiply(scale, y) # test broadcasting to rhs - y = relay.Function(free_vars(y), y) + y = relay.Function(analysis.free_vars(y), y) return y @register_alter_op_layout("nn.conv2d", level=105) @@ -311,20 +311,17 @@ def expected(): y = relay.add(y, bias) # test broadcasting to lhs y = relay.multiply(scale, y) # test broadcasting to rhs y = relay.layout_transform(y, "NCHW16c", "NCHW") - y = relay.Function(free_vars(y), y) + y = relay.Function(analysis.free_vars(y), y) return y a = before() - a = infer_type(a) - a = canonicalize_ops(a) - a = infer_type(a) - a = alter_op_layout(a) - a = infer_type(a) + a = run_opt_pass(a, [transform.CanonicalizeOps(), + transform.AlterOpLayout()]) b = expected() - b = infer_type(b) + b = run_opt_pass(b, transform.InferType()) - assert alpha_equal(a, b), "Actual = \n" + str(a) + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) def test_alter_layout_scalar(): """Test alternating the layout of a conv2d. 
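# A minimal sketch (shapes and names are illustrative, not taken from the patch) of the
# Sequential/PassContext pattern that run_opt_pass in these alter-op-layout tests builds on:
# the passes are grouped into a Sequential, run on a Module under an explicit opt_level,
# and the optimized function is read back via entry_func.
import tvm
from tvm import relay
from tvm.relay import transform

x = relay.var("x", shape=(1, 64, 56, 56))
weight = relay.var("weight", shape=(64, 64, 3, 3))
y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
func = relay.Function([x, weight], y)

mod = relay.Module.from_expr(func)
seq = transform.Sequential([transform.CanonicalizeOps(),
                            transform.AlterOpLayout()])
with transform.PassContext(opt_level=3):
    mod = seq(mod)
altered = mod[mod.entry_func]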
@@ -335,7 +332,7 @@ def before(): weight = relay.var("weight") y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) y = relay.add(y, relay.const(1, "float32")) - y = relay.Function(free_vars(y), y) + y = relay.Function(analysis.free_vars(y), y) return y @register_alter_op_layout("nn.conv2d", level=106) @@ -358,20 +355,17 @@ def expected(): y = relay.add(y, relay.const(1.0, "float32")) y = relay.layout_transform(y, "NCHW16c", "NCHW") - y = relay.Function(free_vars(y), y) + y = relay.Function(analysis.free_vars(y), y) return y a = before() - a = infer_type(a) - a = canonicalize_ops(a) - a = infer_type(a) - a = alter_op_layout(a) - a = infer_type(a) + a = run_opt_pass(a, [transform.CanonicalizeOps(), + transform.AlterOpLayout()]) b = expected() - b = infer_type(b) + b = run_opt_pass(b, transform.InferType()) - assert alpha_equal(a, b), "Actual = \n" + str(a) + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) def test_alter_layout_concatenate(): """ """ @@ -388,7 +382,7 @@ def before(): kernel_size=(3, 3), padding=(1, 1)) ret = relay.concatenate([y, y1], axis=1) - y = relay.Function(free_vars(ret), ret) + y = relay.Function(analysis.free_vars(ret), ret) return y @register_alter_op_layout("nn.conv2d", level=107) @@ -415,18 +409,16 @@ def expected(): data_layout='NCHW16c') ret = relay.concatenate([y, y1], axis=1) ret = relay.layout_transform(ret, "NCHW16c", "NCHW") - y = relay.Function(free_vars(ret), ret) + y = relay.Function(analysis.free_vars(ret), ret) return y a = before() - a = infer_type(a) - a = alter_op_layout(a) - a = infer_type(a) + a = run_opt_pass(a, transform.AlterOpLayout()) b = expected() - b = infer_type(b) + b = run_opt_pass(b, transform.InferType()) - assert alpha_equal(a, b), "Actual = \n" + str(a) + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) def test_alter_layout_nchw_upsamping_op(): @@ -437,7 +429,7 @@ def before(): y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1)) y = relay.nn.upsampling(y, scale=2) y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2)) - y = relay.Function(free_vars(y), y) + y = relay.Function(analysis.free_vars(y), y) return y @register_alter_op_layout("nn.conv2d", level=108) @@ -456,21 +448,17 @@ def expected(): y = relay.nn.upsampling(y, scale=2, layout="NCHW16c") y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2), layout='NCHW16c') y = relay.layout_transform(y, "NCHW16c", "NCHW") - y = relay.Function(free_vars(y), y) + y = relay.Function(analysis.free_vars(y), y) return y a = before() - a = infer_type(a) - a = canonicalize_ops(a) - a = infer_type(a) - - a = alter_op_layout(a) - a = infer_type(a) + a = run_opt_pass(a, [transform.CanonicalizeOps(), + transform.AlterOpLayout()]) b = expected() - b = infer_type(b) + b = run_opt_pass(b, transform.InferType()) - assert alpha_equal(a, b), "Actual = \n" + str(a) + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) def test_alter_layout_strided_slice(): @@ -480,7 +468,7 @@ def before(): weight = relay.var('weight', shape=(32, 32, 3, 3)) y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1)) y = relay.strided_slice(y, begin=[0, 16], end=[None, None]) - y = relay.Function(free_vars(y), y) + y = relay.Function(analysis.free_vars(y), y) return y @register_alter_op_layout("nn.conv2d", level=109) @@ -498,21 +486,17 @@ def expected(): data_layout="NCHW4c") y = relay.strided_slice(y, begin=[0, 4], end=[None, 8]) y = relay.layout_transform(y, "NCHW4c", "NCHW") - y = 
relay.Function(free_vars(y), y) + y = relay.Function(analysis.free_vars(y), y) return y a = before() - a = infer_type(a) - a = canonicalize_ops(a) - a = infer_type(a) - - a = alter_op_layout(a) - a = infer_type(a) + a = run_opt_pass(a, [transform.CanonicalizeOps(), + transform.AlterOpLayout()]) b = expected() - b = infer_type(b) + b = run_opt_pass(b, transform.InferType()) - assert alpha_equal(a, b), "Actual = \n" + str(a) + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) def test_alter_layout_depthwise_conv2d(): """Test depthwise_conv2d operator""" @@ -520,7 +504,7 @@ def before(): x = relay.var("x", shape=(1, 32, 56, 56)) w = relay.var("w", shape=(32, 1, 3, 3)) y = relay.nn.conv2d(x, w, padding=(1, 1), channels=32, kernel_size=(3, 3), groups=32) - y = relay.Function(free_vars(y), y) + y = relay.Function(analysis.free_vars(y), y) return y import topi @@ -538,20 +522,17 @@ def expected(): groups=32, data_layout="NCHW8c", kernel_layout="OIHW1i8o", out_layout="NCHW8c") y = relay.layout_transform(y, "NCHW8c", "NCHW") - y = relay.Function(free_vars(y), y) + y = relay.Function(analysis.free_vars(y), y) return y a = before() - a = infer_type(a) - a = canonicalize_ops(a) - a = infer_type(a) - a = alter_op_layout(a) - a = infer_type(a) + a = run_opt_pass(a, [transform.CanonicalizeOps(), + transform.AlterOpLayout()]) b = expected() - b = infer_type(b) + b = run_opt_pass(b, transform.InferType()) - assert(alpha_equal(a, b)) + assert(analysis.alpha_equal(a, b)) def test_alter_layout_prelu(): """Test PRelu operator""" @@ -561,7 +542,7 @@ def before(): alpha = relay.var("alpha", relay.IncompleteType()) y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) y = relay.nn.prelu(y, alpha) - y = relay.Function(free_vars(y), y) + y = relay.Function(analysis.free_vars(y), y) return y @register_alter_op_layout("nn.conv2d", level=111) @@ -584,20 +565,16 @@ def expected(): data_layout="NCHW16c") y = relay.layout_transform(y, "NCHW16c", "NCHW") y = relay.nn.prelu(y, alpha) - y = relay.Function(free_vars(y), y) + y = relay.Function(analysis.free_vars(y), y) return y a = before() - a = infer_type(a) - a = canonicalize_ops(a) - a = infer_type(a) - a = alter_op_layout(a) - a = infer_type(a) + a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()]) b = expected() - b = infer_type(b) + b = run_opt_pass(b, transform.InferType()) - assert(alpha_equal(a, b)) + assert(analysis.alpha_equal(a, b)) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index 84a5c87490796..86ebf73d3dd69 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -22,6 +22,16 @@ from tvm import relay from tvm.contrib import graph_runtime from tvm.relay.expr_functor import ExprMutator +from tvm.relay import transform + + +def run_opt_pass(expr, passes): + passes = passes if isinstance(passes, list) else [passes] + mod = relay.Module.from_expr(expr) + seq = transform.Sequential(passes) + with transform.PassContext(opt_level=3): + mod = seq(mod) + return mod[mod.entry_func] def test_redundant_annotation(): @@ -39,9 +49,8 @@ def annotated(): sub2 = relay.subtract(_add2, z) func = relay.Function([x, y, z], relay.Tuple([sub1, sub2])) - func = relay.ir_pass.infer_type(func) - func = relay.ir_pass.rewrite_annotated_ops(func, - ctx1.device_type) + func = run_opt_pass(func, + transform.RewriteAnnotatedOps(ctx1.device_type)) return func def expected(): @@ -53,9 +62,9 @@ def 
expected(): func = relay.Function([x, y, z], relay.Tuple([sub1, sub2])) return func - annotated_func = relay.ir_pass.infer_type(annotated()) - expected_func = relay.ir_pass.infer_type(expected()) - assert relay.ir_pass.alpha_equal(annotated_func, expected_func) + annotated_func = annotated() + expected_func = run_opt_pass(expected(), transform.InferType()) + assert relay.analysis.alpha_equal(annotated_func, expected_func) def test_annotate_expr(): @@ -70,9 +79,8 @@ def annotated(): _add = relay.annotation.on_device(add, ctx1) sub = relay.subtract(_add, z) _sub = relay.annotation.on_device(sub, ctx2) - expr = relay.ir_pass.infer_type(_sub) - expr = relay.ir_pass.rewrite_annotated_ops(expr, - ctx1.device_type) + expr = run_opt_pass(_sub, + transform.RewriteAnnotatedOps(ctx1.device_type)) return expr def expected(): @@ -81,9 +89,9 @@ def expected(): sub = relay.subtract(copy_add_sub, z) return sub - annotated_expr = relay.ir_pass.infer_type(annotated()) - expected_expr = relay.ir_pass.infer_type(expected()) - assert relay.ir_pass.graph_equal(annotated_expr, expected_expr) + annotated_expr = annotated() + expected_expr = run_opt_pass(expected(), transform.InferType()) + assert relay.analysis.graph_equal(annotated_expr, expected_expr) def test_annotate_all(): @@ -100,9 +108,8 @@ def annotated(): _sub = relay.annotation.on_device(sub, ctx2) func = relay.Function([x, y, z], _sub) - func = relay.ir_pass.infer_type(func) - func = relay.ir_pass.rewrite_annotated_ops(func, - ctx1.device_type) + func = run_opt_pass(func, + transform.RewriteAnnotatedOps(ctx1.device_type)) return func def expected(): @@ -111,9 +118,9 @@ def expected(): func = relay.Function([x, y, z], sub) return func - annotated_func = relay.ir_pass.infer_type(annotated()) - expected_func = relay.ir_pass.infer_type(expected()) - assert relay.ir_pass.alpha_equal(annotated_func, expected_func) + annotated_func = annotated() + expected_func = run_opt_pass(expected(), transform.InferType()) + assert relay.analysis.graph_equal(annotated_func, expected_func) def test_annotate_none(): @@ -127,9 +134,8 @@ def annotated(): add = relay.add(x, y) sub = relay.subtract(add, z) func = relay.Function([x, y, z], sub) - func = relay.ir_pass.infer_type(func) - func = relay.ir_pass.rewrite_annotated_ops(func, - ctx1.device_type) + func = run_opt_pass(func, + transform.RewriteAnnotatedOps(ctx1.device_type)) return func def expected(): @@ -138,15 +144,15 @@ def expected(): func = relay.Function([x, y, z], sub) return func - annotated_func = relay.ir_pass.infer_type(annotated()) - expected_func = relay.ir_pass.infer_type(expected()) - assert relay.ir_pass.alpha_equal(annotated_func, expected_func) + annotated_func = annotated() + expected_func = run_opt_pass(expected(), transform.InferType()) + assert relay.analysis.graph_equal(annotated_func, expected_func) def check_annotated_graph(annotated_func, expected_func): - annotated_func = relay.ir_pass.infer_type(annotated_func) - expected_func = relay.ir_pass.infer_type(expected_func) - assert relay.ir_pass.alpha_equal(annotated_func, expected_func) + annotated_func = run_opt_pass(annotated_func, transform.InferType()) + expected_func = run_opt_pass(expected_func, transform.InferType()) + assert relay.analysis.alpha_equal(annotated_func, expected_func) def test_conv_network(): @@ -189,9 +195,8 @@ def original(): padding=(1, 1)) func = relay.Function([data1, data2, weight], conv2d_3) - func = relay.ir_pass.infer_type(func) - func = relay.ir_pass.rewrite_annotated_ops(func, - tvm.context(3).device_type) + func = 
run_opt_pass( + func, transform.RewriteAnnotatedOps(tvm.context(3).device_type)) return func @@ -221,9 +226,8 @@ def annotated(): _conv2d_3 = relay.annotation.on_device(conv2d_3, dev2) func = relay.Function([data1, data2, weight], _conv2d_3) - func = relay.ir_pass.infer_type(func) - func = relay.ir_pass.rewrite_annotated_ops(func, - tvm.context(3).device_type) + func = run_opt_pass( + func, transform.RewriteAnnotatedOps(tvm.context(3).device_type)) return func class ScheduleConv2d(ExprMutator): @@ -241,7 +245,8 @@ def visit_call(self, expr): def annotate_with_visitor(func): sched = ScheduleConv2d(dev2) func = sched.visit(func) - func = relay.ir_pass.rewrite_annotated_ops(func, dev1.device_type) + func = run_opt_pass( + func, transform.RewriteAnnotatedOps(dev1.device_type)) return func def expected(): @@ -273,10 +278,8 @@ def expected(): def check_storage_and_device_types(): func = annotated() - func = relay.ir_pass.rewrite_annotated_ops(func, 3) - func = relay.ir_pass.infer_type(func) - func = relay.ir_pass.fuse_ops(func, opt_level=2) - func = relay.ir_pass.infer_type(func) + func = run_opt_pass(func, [transform.RewriteAnnotatedOps(3), + transform.FuseOps(2)]) smap = relay.backend._backend.GraphPlanMemory(func) storage_ids = [] device_types = [] @@ -377,9 +380,8 @@ def annotated(): _exp = relay.annotation.on_device(exp, dev_ctx) func = relay.Function([x, y], _exp) - func = relay.ir_pass.infer_type(func) - func = relay.ir_pass.rewrite_annotated_ops(func, - cpu_ctx.device_type) + func = run_opt_pass( + func, transform.RewriteAnnotatedOps(cpu_ctx.device_type)) return func def expected(): @@ -424,9 +426,8 @@ def annotated(): _exp = relay.annotation.on_device(exp, dev_ctx) func = relay.Function([x, y], _exp) - func = relay.ir_pass.infer_type(func) - func = relay.ir_pass.rewrite_annotated_ops(func, - cpu_ctx.device_type) + func = run_opt_pass( + func, transform.RewriteAnnotatedOps(cpu_ctx.device_type)) return func annotated_func = annotated() @@ -449,9 +450,8 @@ def annotated(): _exp = relay.annotation.on_device(exp, cpu_ctx) func = relay.Function([x, y], _exp) - func = relay.ir_pass.infer_type(func) - func = relay.ir_pass.rewrite_annotated_ops(func, - dev_ctx.device_type) + func = run_opt_pass( + func, transform.RewriteAnnotatedOps(dev_ctx.device_type)) return func def expected(): @@ -495,7 +495,7 @@ def run_unpropagatable_graph(dev, tgt): \ / subtract """ - + a = relay.var("a", shape=(10, 10)) b = relay.var("b", shape=(10, 10)) c = relay.var("c", shape=(10, 10)) @@ -507,13 +507,13 @@ def run_unpropagatable_graph(dev, tgt): tmp_add = a_data + b_data tmp_mul = np.multiply(c_data, d_data) ref_res = np.subtract(tmp_add, tmp_mul) - + fallback_device = tvm.context("cpu") target = {"cpu": "llvm", dev: tgt} cpu_ctx = fallback_device dev_ctx = tvm.context(dev) - - def annotated(): + + def annotated(): add = relay.add(a, b) _add = relay.annotation.on_device(add, dev_ctx) mul = relay.multiply(c, d) @@ -521,19 +521,18 @@ def annotated(): sub = relay.subtract(_add, _mul) _sub = relay.annotation.on_device(sub, dev_ctx) func = relay.Function([a, b, c, d], _sub) - func = relay.ir_pass.infer_type(func) - func = relay.ir_pass.rewrite_annotated_ops(func, - dev_ctx.device_type) + func = run_opt_pass( + func, transform.RewriteAnnotatedOps(dev_ctx.device_type)) return func - - def expected(): + + def expected(): add = relay.add(a, b) mul = relay.multiply(c, d) copy_mul_sub = relay.device_copy(mul, cpu_ctx, dev_ctx) sub = relay.subtract(add, copy_mul_sub) func = relay.Function([a, b, c, d], sub) return func - + 
annotated_func = annotated() expected_func = expected() expected_index = [2, 2, 2, 1, 1, 1, 2, 2] @@ -553,7 +552,7 @@ def expected(): mod.run() res = mod.get_output(0).asnumpy() tvm.testing.assert_allclose(res, ref_res, rtol=1e-5, atol=1e-5) - + def test_check_run(): for dev, tgt in [("opencl", "opencl"), ("cuda", "cuda"), @@ -580,7 +579,7 @@ def expected(): elem0 = relay.device_copy(split[0], gpu_ctx, cpu_ctx) elem1 = relay.device_copy(split[1], gpu_ctx, cpu_ctx) sub = elem0 - elem1 - func = relay.Function(relay.ir_pass.free_vars(sub), sub) + func = relay.Function(relay.analysis.free_vars(sub), sub) return func def annotated(): @@ -590,13 +589,14 @@ def annotated(): split = relay.annotation.on_device(split, gpu_ctx) split = relay.TupleWrapper(split, 3) sub = split[0] - split[1] - func = relay.Function(relay.ir_pass.free_vars(sub), sub) - func = relay.ir_pass.rewrite_annotated_ops(func, cpu_ctx.device_type) + func = relay.Function(relay.analysis.free_vars(sub), sub) + func = run_opt_pass( + func, transform.RewriteAnnotatedOps(cpu_ctx.device_type)) return func - annotated_func = relay.ir_pass.infer_type(annotated()) - expected_func = relay.ir_pass.infer_type(expected()) - assert relay.ir_pass.graph_equal(annotated_func, expected_func) + annotated_func = annotated() + expected_func = run_opt_pass(expected(), transform.InferType()) + assert relay.analysis.graph_equal(annotated_func, expected_func) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_canonicalize_cast.py b/tests/python/relay/test_pass_canonicalize_cast.py index 04478e94039a2..c7b88a8dc9e3f 100644 --- a/tests/python/relay/test_pass_canonicalize_cast.py +++ b/tests/python/relay/test_pass_canonicalize_cast.py @@ -60,8 +60,11 @@ def check(shape): mod = seq(mod) y = mod[mod.entry_func.name_hint] y_expected = expected(data, conv_weight, bias1, bias2) - y_expected = relay.ir_pass.infer_type(y_expected) - assert relay.ir_pass.alpha_equal(y, y_expected) + gv = relay.GlobalVar("expected") + mod[gv] = y_expected + mod = _transform.InferType()(mod) + y_expected = mod["expected"] + assert relay.analysis.alpha_equal(y, y_expected) check((1, 16, 7, 7)) diff --git a/tests/python/relay/test_pass_check_kind.py b/tests/python/relay/test_pass_check_kind.py index 4d9a2e77eae28..7049ba6f11ed1 100644 --- a/tests/python/relay/test_pass_check_kind.py +++ b/tests/python/relay/test_pass_check_kind.py @@ -16,7 +16,7 @@ # under the License. import tvm from tvm import relay -from tvm.relay.ir_pass import check_kind +from tvm.relay.analysis import check_kind from nose.tools import raises diff --git a/tests/python/relay/test_pass_combine_parallel_conv2d.py b/tests/python/relay/test_pass_combine_parallel_conv2d.py index 3bb656b2bda56..4ea11f42f40dc 100644 --- a/tests/python/relay/test_pass_combine_parallel_conv2d.py +++ b/tests/python/relay/test_pass_combine_parallel_conv2d.py @@ -15,7 +15,19 @@ # specific language governing permissions and limitations # under the License. 
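# Usage sketch for the module-based helpers defined just below (run_combine_parallel and
# run_opt_pass): the parallel-conv2d combination pass is constructed with its minimum
# branch count and applied to a Module built from the test expression. The shapes and
# variable names here are examples only.
import tvm
from tvm import relay
from tvm.relay import transform

x = relay.var("x", shape=(1, 4, 16, 16))
w1 = relay.var("w1", shape=(8, 4, 1, 1))
w2 = relay.var("w2", shape=(8, 4, 1, 1))
y1 = relay.nn.conv2d(x, w1, channels=8, kernel_size=(1, 1))
y2 = relay.nn.conv2d(x, w2, channels=8, kernel_size=(1, 1))
func = relay.Function([x, w1, w2], relay.concatenate((y1, y2), axis=1))

mod = relay.Module.from_expr(func)
mod = transform.CombineParallelConv2D(min_num_branches=2)(mod)
combined = mod[mod.entry_func]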
from tvm import relay -import numpy as np +from tvm.relay import transform + + +def run_combine_parallel(expr, min_num_branches=3): + mod = relay.Module.from_expr(expr) + mod = transform.CombineParallelConv2D(min_num_branches)(mod) + return mod[mod.entry_func] + +def run_opt_pass(expr, opt_pass): + assert isinstance(opt_pass, transform.Pass) + mod = relay.Module.from_expr(expr) + mod = opt_pass(mod) + return mod[mod.entry_func] def test_combine_parallel_conv2d(): @@ -54,12 +66,11 @@ def check(x_shape, channels1, channels2, channels3, channels4): w4 = relay.var("w4", shape=(channels4, in_c, 1, 1)) y_before = before(x, w1, w2, w3, w4) - y = relay.ir_pass.infer_type(y_before) - y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2) - y = relay.ir_pass.infer_type(y) + y = run_opt_pass(y_before, + transform.CombineParallelConv2D(min_num_branches=2)) y_expected = expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4) - y_expected = relay.ir_pass.infer_type(y_expected) - assert relay.ir_pass.alpha_equal(y, y_expected) + y_expected = run_opt_pass(y_expected, transform.InferType()) + assert relay.analysis.alpha_equal(y, y_expected) check((1, 4, 16, 16), 4, 4, 4, 4) check((1, 4, 16, 16), 4, 8, 4, 7) @@ -101,12 +112,11 @@ def check(x_shape, channels1, channels2): scale2 = relay.var("scale2", shape=(channels2, 1, 1)) bias = relay.var("bias", shape=(channels2, 1, 1)) y_before = before(x, w1, w2, scale1, scale2, bias) - y = relay.ir_pass.infer_type(y_before) - y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2) - y = relay.ir_pass.infer_type(y) + y = run_opt_pass(y_before, + transform.CombineParallelConv2D(min_num_branches=2)) y_expected = expected(x, w1, w2, scale1, scale2, bias, channels1, channels2) - y_expected = relay.ir_pass.infer_type(y_expected) - assert relay.ir_pass.alpha_equal(y, y_expected) + y_expected = run_opt_pass(y_expected, transform.InferType()) + assert relay.analysis.alpha_equal(y, y_expected) check((1, 4, 16, 16), 4, 8) @@ -141,12 +151,11 @@ def check(x_shape, channels1, channels2): scale1 = relay.var("scale1", shape=(1,)) scale2 = relay.var("scale2", shape=(1,)) y_before = before(x, w1, w2, scale1, scale2) - y = relay.ir_pass.infer_type(y_before) - y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2) - y = relay.ir_pass.infer_type(y) + y = run_opt_pass(y_before, + transform.CombineParallelConv2D(min_num_branches=2)) y_expected = expected(x, w1, w2, scale1, scale2, channels1, channels2) - y_expected = relay.ir_pass.infer_type(y_expected) - assert relay.ir_pass.alpha_equal(y, y_expected) + y_expected = run_opt_pass(y_expected, transform.InferType()) + assert relay.analysis.alpha_equal(y, y_expected) check((1, 4, 16, 16), 4, 8) @@ -178,12 +187,11 @@ def check(x_shape, repeat): out_c = in_c // 2 w = relay.var("w", shape=(out_c, in_c, 1, 1)) y_before = before(x, w, repeat) - y = relay.ir_pass.infer_type(y_before) - y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2) - y = relay.ir_pass.infer_type(y) + y = run_opt_pass(y_before, + transform.CombineParallelConv2D(min_num_branches=2)) y_expected = expected(x, w, out_c, repeat) - y_expected = relay.ir_pass.infer_type(y_expected) - assert relay.ir_pass.alpha_equal(y, y_expected) + y_expected = run_opt_pass(y_expected, transform.InferType()) + assert relay.analysis.alpha_equal(y, y_expected) check((1, 4, 16, 16), 4) diff --git a/tests/python/relay/test_pass_dead_code_elimination.py b/tests/python/relay/test_pass_dead_code_elimination.py index c3b12fea44867..17a836beecd57 100644 
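# A small sketch of the flow exercised by the dead-code-elimination tests below: an unused
# let-binding is wrapped into a Module, the pass (now Module -> Module only) is run, and the
# simplified body is read back out. Names and shapes are illustrative.
import tvm
from tvm import relay
from tvm.relay import transform

x = relay.var("x", shape=(1,), dtype="float32")
y = relay.var("y", shape=(1,), dtype="float32")
body = relay.Let(x, y, relay.const(2.0))   # `x` is bound but never used

mod = relay.Module.from_expr(body)
mod = transform.DeadCodeElimination()(mod)
optimized = mod[mod.entry_func].body       # the unused binding is dropped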
--- a/tests/python/relay/test_pass_dead_code_elimination.py +++ b/tests/python/relay/test_pass_dead_code_elimination.py @@ -19,7 +19,7 @@ import tvm from tvm import relay from tvm.relay import Function, transform -from tvm.relay.ir_pass import alpha_equal, graph_equal, free_vars +from tvm.relay.analysis import alpha_equal, graph_equal, free_vars from tvm.relay.op import log, add, equal, subtract @@ -45,28 +45,36 @@ def __init__(self): e = env() +def run_opt_pass(expr, opt_pass): + assert isinstance(opt_pass, transform.Pass) + mod = relay.Module.from_expr(expr) + mod = opt_pass(mod) + entry = mod[mod.entry_func] + return entry if isinstance(expr, relay.Function) else entry.body + + def test_let(): orig = relay.Let(e.x, e.y, e.z) - orig = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination()) + orig = run_opt_pass(orig, transform.DeadCodeElimination()) assert alpha_equal(Function(free_vars(orig), orig), Function([e.z], e.z)) def test_used_let(): orig = relay.Let(e.c, e.one, e.c + e.c) - orig = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination()) + orig = run_opt_pass(orig, transform.DeadCodeElimination()) expected = relay.Let(e.c, e.one, e.c + e.c) assert alpha_equal(Function([e.c], orig), Function([e.c], expected)) @nottest def test_inline(): orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c)) - orig = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination()) + orig = run_opt_pass(orig, transform.DeadCodeElimination()) assert alpha_equal(Function(free_vars(orig), orig), Function([e.d], e.d)) def test_chain_unused_let(): orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e)) - orig = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination()) + orig = run_opt_pass(orig, transform.DeadCodeElimination()) assert alpha_equal(Function(free_vars(orig), orig), Function([e.e], e.e)) @@ -93,17 +101,17 @@ def test_recursion(): log(data)])) value = relay.Function([n, data], funcbody, e.float32, []) orig = relay.Let(f, value, relay.Call(f, [relay.const(2), relay.const(10000.0)])) - dced = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination()) - orig = transform.OptimizeOnExpr(orig, transform.InferType()) + dced = run_opt_pass(orig, transform.DeadCodeElimination()) + orig = run_opt_pass(orig, transform.InferType()) assert graph_equal(dced, orig) - dced = transform.OptimizeOnExpr(relay.Let(f, value, e.three), - transform.DeadCodeElimination()) + dced = run_opt_pass(relay.Let(f, value, e.three), + transform.DeadCodeElimination()) assert alpha_equal(dced, e.three) def test_op_let(): - dced = transform.OptimizeOnExpr(add(relay.Let(e.a, e.one, e.three), e.two), - transform.DeadCodeElimination()) + dced = run_opt_pass(add(relay.Let(e.a, e.one, e.three), e.two), + transform.DeadCodeElimination()) assert alpha_equal(dced, add(e.three, e.two)) @@ -112,10 +120,10 @@ def test_tuple_get_item(): t = relay.Var('t', tt) a = relay.Var('a') g = relay.TupleGetItem(t, 0) - dced = transform.OptimizeOnExpr(g, transform.DeadCodeElimination()) + dced = run_opt_pass(g, transform.DeadCodeElimination()) assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g), g)) orig = relay.TupleGetItem(relay.Let(a, e.one, t), 0) - dced = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination()) + dced = run_opt_pass(orig, transform.DeadCodeElimination()) assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g), g)) diff --git a/tests/python/relay/test_pass_eliminate_common_subexpr.py b/tests/python/relay/test_pass_eliminate_common_subexpr.py index 
1ebd834d4da76..f08d0dfd1f262 100644 --- a/tests/python/relay/test_pass_eliminate_common_subexpr.py +++ b/tests/python/relay/test_pass_eliminate_common_subexpr.py @@ -17,7 +17,15 @@ """Test eliminate common subexpr pass""" from tvm import relay from tvm.relay.op import register_alter_op_layout -from tvm.relay import ir_pass +from tvm.relay import transform, analysis + + +def run_opt_pass(expr, opt_pass): + assert isinstance(opt_pass, transform.Pass) + mod = relay.Module.from_expr(expr) + mod = opt_pass(mod) + entry = mod[mod.entry_func] + return entry if isinstance(expr, relay.Function) else entry.body def test_simple(): @@ -37,11 +45,11 @@ def expected(): y = relay.add(y, relay.const(1.0, "float32")) y = relay.add(y, y) f = relay.Function([x], y) - return f + return run_opt_pass(f, transform.InferType()) z = before() - z = ir_pass.eliminate_common_subexpr(z) - assert ir_pass.alpha_equal(z, expected()) + z = run_opt_pass(z, transform.EliminateCommonSubexpr()) + assert analysis.alpha_equal(z, expected()) def test_callback(): @@ -62,7 +70,7 @@ def expected(): y2 = relay.add(y, relay.const(1.0, "float32")) y = relay.add(y1, y2) f = relay.Function([x], y) - return f + return run_opt_pass(f, transform.InferType()) def fskip(expr): if isinstance(expr, relay.expr.Call) and expr.op.name == 'add': @@ -70,8 +78,8 @@ def fskip(expr): return False z = before() - z = ir_pass.eliminate_common_subexpr(z, fskip) - assert ir_pass.alpha_equal(z, expected()) + z = run_opt_pass(z, transform.EliminateCommonSubexpr(fskip)) + assert analysis.alpha_equal(z, expected()) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_eta_expand.py b/tests/python/relay/test_pass_eta_expand.py index 4e20b02357d3d..5308e472129a6 100644 --- a/tests/python/relay/test_pass_eta_expand.py +++ b/tests/python/relay/test_pass_eta_expand.py @@ -30,10 +30,11 @@ def test_eta_expand_basic(): y = relay.var('y', 'int32') expected = relay.Function([y], orig(y)) - - got = relay.ir_pass.infer_type(got, mod) - expected = relay.ir_pass.infer_type(expected, mod) - assert(relay.ir_pass.alpha_equal(got, expected)) + gv = relay.GlobalVar("gv") + mod[gv] = expected + mod = _transform.InferType()(mod) + expected = mod["gv"] + assert(relay.analysis.alpha_equal(got, expected)) if __name__ == "__main__": test_eta_expand_basic() diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py index 2abeaa8f8db8b..881ec8f912c9c 100644 --- a/tests/python/relay/test_pass_fold_constant.py +++ b/tests/python/relay/test_pass_fold_constant.py @@ -17,13 +17,24 @@ import numpy as np import tvm from tvm import relay +from tvm.relay import transform + + +def run_opt_pass(expr, opt_pass): + assert isinstance(opt_pass, transform.Pass) + + mod = relay.Module.from_expr(expr) + mod = opt_pass(mod) + entry = mod[mod.entry_func] + return entry if isinstance(expr, relay.Function) else entry.body def test_fold_const(): c_data = np.array([1, 2, 3]).astype("float32") + t = relay.TensorType([1, 2, 3], "float32") def before(): c = relay.const(c_data) - x = relay.var("x") + x = relay.var("x", t) y = relay.add(c, c) y = relay.multiply(y, relay.const(2, "float32")) y = relay.add(x, y) @@ -31,7 +42,7 @@ def before(): return relay.Function([x], z) def expected(): - x = relay.var("x") + x = relay.var("x", t) c_folded = (c_data + c_data) * 2 y = relay.add(x, relay.const(c_folded)) z = relay.add(y, relay.const(c_data)) @@ -39,19 +50,21 @@ def expected(): def fail(x): raise RuntimeError() + # the fold constant should work on any 
context. with tvm.build_config(add_lower_pass=[(0, fail)]): with tvm.target.create("cuda"): - zz = relay.ir_pass.fold_constant(before()) - zexpected = expected() - assert relay.ir_pass.alpha_equal(zz, zexpected) + zz = run_opt_pass(before(), transform.FoldConstant()) + zexpected = run_opt_pass(expected(), transform.InferType()) + assert relay.analysis.alpha_equal(zz, zexpected) def test_fold_let(): c_data = np.array(1).astype("float32") + t = relay.TensorType([1], "float32") def before(): sb = relay.ScopeBuilder() - x = relay.var("x") + x = relay.var("x", t) t1 = sb.let("t1", relay.const(c_data)) t2 = sb.let("t2", relay.add(t1, t1)) t3 = sb.let("t3", relay.add(t2, x)) @@ -60,22 +73,23 @@ def before(): def expected(): sb = relay.ScopeBuilder() - x = relay.var("x") + x = relay.var("x", t) c_folded = (c_data + c_data) t3 = sb.let("t3", relay.add(relay.const(c_folded), x)) sb.ret(t3) return relay.Function([x], sb.get()) - zz = relay.ir_pass.fold_constant(before()) - zexpected = expected() - assert relay.ir_pass.graph_equal(zz, zexpected) + zz = run_opt_pass(before(), transform.FoldConstant()) + zexpected = run_opt_pass(expected(), transform.InferType()) + assert relay.analysis.graph_equal(zz, zexpected) def test_fold_tuple(): c_data = np.array(1).astype("float32") + t = relay.TensorType([1], "float32") def before(): c = relay.const(c_data) - x = relay.var("x") + x = relay.var("x", t) y = relay.Tuple([x, c]) z = relay.add(y[1], c) z = relay.add(z, y[0]) @@ -83,13 +97,13 @@ def before(): def expected(): c = relay.const(c_data + c_data) - x = relay.var("x") + x = relay.var("x", t) z = relay.add(c, x) return relay.Function([x], z) - zz = relay.ir_pass.fold_constant(before()) - zexpected = expected() - assert relay.ir_pass.graph_equal(zz, zexpected) + zz = run_opt_pass(before(), transform.FoldConstant()) + zexpected = run_opt_pass(expected(), transform.InferType()) + assert relay.analysis.graph_equal(zz, zexpected) def test_fold_concat(): @@ -106,9 +120,9 @@ def expected(): y = relay.const(y_data) return relay.Function([], y) - zz = relay.ir_pass.fold_constant(before()) - zexpected = expected() - assert relay.ir_pass.graph_equal(zz, zexpected) + zz = run_opt_pass(before(), transform.FoldConstant()) + zexpected = run_opt_pass(expected(), transform.InferType()) + assert relay.analysis.graph_equal(zz, zexpected) def test_fold_shape_of(): @@ -123,17 +137,13 @@ def expected(dtype): x = relay.var("x", shape=c_shape, dtype="float32") y = relay.var("y", shape=c_shape, dtype="float32") z = relay.const(np.array(c_shape).astype(dtype), dtype=dtype) - return relay.ir_pass.infer_type(relay.Function([x, y], z)) + func = relay.Function([x, y], z) + return func for dtype in ["int32", "float32"]: - zbefore = before(dtype) - zz = relay.ir_pass.fold_constant(zbefore) - assert relay.ir_pass.graph_equal(zz, zbefore) - - zz = relay.ir_pass.infer_type(zbefore) - zz = relay.ir_pass.fold_constant(zz) - zexpected = expected(dtype) - assert relay.ir_pass.graph_equal(zz, zexpected) + zz = run_opt_pass(before(dtype), transform.FoldConstant()) + zexpected = run_opt_pass(expected(dtype), transform.InferType()) + assert relay.analysis.graph_equal(zz, zexpected) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py index 383f0072059f4..70354fbdaa3b8 100644 --- a/tests/python/relay/test_pass_fold_scale_axis.py +++ b/tests/python/relay/test_pass_fold_scale_axis.py @@ -14,13 +14,23 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. -from tvm import relay import numpy as np +from tvm import relay +from tvm.relay import transform + def _get_positive_scale(size): return np.random.uniform(0.5, 1, size=size).astype('float32') +def run_opt_pass(expr, opt_pass): + assert isinstance(opt_pass, transform.Pass) + mod = relay.Module.from_expr(expr) + mod = opt_pass(mod) + entry = mod[mod.entry_func] + return entry if isinstance(expr, relay.Function) else entry.body + + def test_fold_fwd_simple(): """Simple testcase.""" def before(x, conv_weight, in_bias, in_scale, channels): @@ -59,15 +69,15 @@ def check(shape, channels): in_bias = relay.var("in_bias", shape=(in_channels,)) in_scale = relay.const(_get_positive_scale((in_channels, 1, 1))) y1 = before(x, weight, in_bias, in_scale, channels) - y1 = relay.ir_pass.infer_type(y1) + y1 = run_opt_pass(y1, transform.InferType()) type_dict = {x.name_hint:x.checked_type for x in y1.params} weight = relay.var("weight", type_dict["weight"]) - y1_folded = relay.ir_pass.forward_fold_scale_axis(y1) + y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) y1_expected = expected(x, weight, in_bias, in_scale, channels) - y1_folded = relay.ir_pass.infer_type(y1_folded) - y1_expected = relay.ir_pass.infer_type(y1_expected) - assert relay.ir_pass.alpha_equal(y1_folded, y1_expected) + y1_folded = run_opt_pass(y1_folded, transform.InferType()) + y1_expected = run_opt_pass(y1_expected, transform.InferType()) + assert relay.analysis.alpha_equal(y1_folded, y1_expected) check((2, 4, 10, 10), 2) @@ -129,14 +139,13 @@ def check(shape, channels): in_bias = relay.var("in_bias", shape=(in_channels,)) in_scale = relay.const(_get_positive_scale(in_channels,)) y1 = before(x, weight, in_bias, in_scale, channels) - y1 = relay.ir_pass.infer_type(y1) - y1_folded = relay.ir_pass.forward_fold_scale_axis(y1) + y1 = run_opt_pass(y1, transform.InferType()) + y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) type_dict = {x.name_hint:x.checked_type for x in y1.params} weight = relay.var("weight", type_dict["weight"]) y1_expected = expected(x, weight, in_bias, in_scale, channels) - y1_folded = relay.ir_pass.infer_type(y1_folded) - y1_expected = relay.ir_pass.infer_type(y1_expected) - assert relay.ir_pass.alpha_equal(y1_folded, y1_expected) + y1_expected = run_opt_pass(y1_expected, transform.InferType()) + assert relay.analysis.alpha_equal(y1_folded, y1_expected) check((2, 4, 10, 3), 3) @@ -152,7 +161,7 @@ def before(x, conv_weight, in_bias, in_scale, channels): data_layout="NHWC", padding=(1, 1)) z = relay.add(y1, x) - return relay.Function(relay.ir_pass.free_vars(z), z) + return relay.Function(relay.analysis.free_vars(z), z) def check(shape, channels): x = relay.var("x", shape=shape) @@ -163,9 +172,9 @@ def check(shape, channels): in_bias = relay.var("in_bias", shape=(in_channels,)) in_scale = relay.const(_get_positive_scale(size=(in_channels,))) y1 = before(x, weight, in_bias, in_scale, channels) - y1 = relay.ir_pass.infer_type(y1) - y1_folded = relay.ir_pass.forward_fold_scale_axis(y1) - assert relay.ir_pass.alpha_equal(y1, y1_folded) + y1 = run_opt_pass(y1, transform.InferType()) + y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) + assert relay.analysis.alpha_equal(y1, y1_folded) check((2, 11, 10, 4), 4) @@ -181,7 +190,7 @@ def before(x, conv_weight, in_bias, in_scale, channels): data_layout="NHWC", padding=(1, 1)) z = relay.add(y1, x) - return relay.Function(relay.ir_pass.free_vars(z), 
z) + return relay.Function(relay.analysis.free_vars(z), z) def check(shape, channels, in_scale): x = relay.var("x", shape=shape) @@ -191,9 +200,9 @@ def check(shape, channels, in_scale): weight = relay.var("weight") in_bias = relay.var("in_bias", shape=(in_channels,)) y1 = before(x, weight, in_bias, in_scale, channels) - y1 = relay.ir_pass.infer_type(y1) - y1_folded = relay.ir_pass.forward_fold_scale_axis(y1) - assert relay.ir_pass.alpha_equal(y1, y1_folded) + y1 = run_opt_pass(y1, transform.InferType()) + y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) + assert relay.analysis.alpha_equal(y1, y1_folded) in_scale = relay.var("in_scale", shape=(4,)) check((2, 11, 10, 4), 4, in_scale) @@ -231,14 +240,13 @@ def check(shape, channels): in_scale = relay.const(-_get_positive_scale((in_channels, 1, 1))) weight = relay.var("weight") y1 = before(x, weight, in_scale, channels) - y1 = relay.ir_pass.infer_type(y1) + y1 = run_opt_pass(y1, transform.InferType()) type_dict = {x.name_hint:x.checked_type for x in y1.params} weight = relay.var("weight", type_dict["weight"]) - y1_folded = relay.ir_pass.forward_fold_scale_axis(y1) + y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) y1_expected = expected(x, weight, in_scale, channels) - y1_folded = relay.ir_pass.infer_type(y1_folded) - y1_expected = relay.ir_pass.infer_type(y1_expected) - assert relay.ir_pass.alpha_equal(y1_folded, y1_expected) + y1_expected = run_opt_pass(y1_expected, transform.InferType()) + assert relay.analysis.alpha_equal(y1_folded, y1_expected) check((2, 4, 10, 10), 4) @@ -283,14 +291,13 @@ def check(shape, channels): out_scale = relay.const(_get_positive_scale((channels, 1, 1))) y1 = before(x, weight, out_bias, out_scale, channels) - y1 = relay.ir_pass.infer_type(y1) + y1 = run_opt_pass(y1, transform.InferType()) type_dict = {x.name_hint:x.checked_type for x in y1.params} weight = relay.var("weight", type_dict["weight"]) - y1_folded = relay.ir_pass.backward_fold_scale_axis(y1) + y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) y1_expected = expected(x, weight, out_bias, out_scale, channels) - y1_folded = relay.ir_pass.infer_type(y1_folded) - y1_expected = relay.ir_pass.infer_type(y1_expected) - assert relay.ir_pass.alpha_equal(y1_folded, y1_expected) + y1_expected = run_opt_pass(y1_expected, transform.InferType()) + assert relay.analysis.alpha_equal(y1_folded, y1_expected) check((2, 4, 10, 10), 8) @@ -343,14 +350,13 @@ def check(shape, channels): out_scale = relay.const(_get_positive_scale((channels, 1, 1))) y1 = before(x, weight, out_bias, out_scale, channels) - y1 = relay.ir_pass.infer_type(y1) + y1 = run_opt_pass(y1, transform.InferType()) type_dict = {x.name_hint:x.checked_type for x in y1.params} weight = relay.var("weight", type_dict["weight"]) - y1_folded = relay.ir_pass.backward_fold_scale_axis(y1) + y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) y1_expected = expected(x, weight, out_bias, out_scale, channels) - y1_folded = relay.ir_pass.infer_type(y1_folded) - y1_expected = relay.ir_pass.infer_type(y1_expected) - assert relay.ir_pass.alpha_equal(y1_folded, y1_expected) + y1_expected = run_opt_pass(y1_expected, transform.InferType()) + assert relay.analysis.alpha_equal(y1_folded, y1_expected) check((2, 4, 10, 10), 8) @@ -416,14 +422,13 @@ def check(shape, channels): out_scale = relay.const(_get_positive_scale((channels,1, 1))) y1 = before(x, weight, out_bias, out_scale, channels) - y1 = relay.ir_pass.infer_type(y1) + y1 = run_opt_pass(y1, transform.InferType()) 
type_dict = {x.name_hint:x.checked_type for x in y1.params} weight = relay.var("weight", type_dict["weight"]) - y1_folded = relay.ir_pass.backward_fold_scale_axis(y1) + y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) y1_expected = expected(x, weight, out_bias, out_scale, channels) - y1_folded = relay.ir_pass.infer_type(y1_folded) - y1_expected = relay.ir_pass.infer_type(y1_expected) - assert relay.ir_pass.alpha_equal(y1_folded, y1_expected) + y1_expected = run_opt_pass(y1_expected, transform.InferType()) + assert relay.analysis.alpha_equal(y1_folded, y1_expected) check((2, 4, 10, 10), 4) @@ -470,9 +475,9 @@ def check(shape, channels, fbefore): out_bias = relay.var("out_bias", shape=(channels,)) out_scale = relay.const(_get_positive_scale((channels, 1, 1))) y1 = fbefore(x, weight, out_bias, out_scale, channels) - y1 = relay.ir_pass.infer_type(y1) - y1_folded = relay.ir_pass.backward_fold_scale_axis(y1) - assert relay.ir_pass.alpha_equal(y1_folded, y1) + y1 = run_opt_pass(y1, transform.InferType()) + y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) + assert relay.analysis.alpha_equal(y1_folded, y1) check((4, 4, 10, 10), 4, fail1) check((4, 4, 10, 10), 4, fail2) @@ -488,16 +493,16 @@ def before(x, conv_weight, out_scale, channels): padding=(1, 1)) y = relay.nn.relu(y) y = relay.multiply(x, out_scale) - return relay.Function(relay.ir_pass.free_vars(y), y) + return relay.Function(relay.analysis.free_vars(y), y) def check(shape, channels, out_scale): x = relay.var("x", shape=shape) in_channels = shape[1] weight = relay.var("weight") y1 = before(x, weight, out_scale, channels) - y1 = relay.ir_pass.infer_type(y1) - y1_folded = relay.ir_pass.forward_fold_scale_axis(y1) - assert relay.ir_pass.alpha_equal(y1, y1_folded) + y1 = run_opt_pass(y1, transform.InferType()) + y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) + assert relay.analysis.alpha_equal(y1, y1_folded) out_scale = relay.var("in_scale", shape=(4, 1, 1)) check((4, 4, 10, 10), 4, out_scale) @@ -533,14 +538,13 @@ def check(shape, channels): weight = relay.var("weight") out_scale = relay.const(-_get_positive_scale((channels, 1, 1))) y1 = before(x, weight, out_scale, channels) - y1 = relay.ir_pass.infer_type(y1) + y1 = run_opt_pass(y1, transform.InferType()) type_dict = {x.name_hint:x.checked_type for x in y1.params} weight = relay.var("weight", type_dict["weight"]) - y1_folded = relay.ir_pass.backward_fold_scale_axis(y1) + y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) y1_expected = expected(x, weight, out_scale, channels) - y1_folded = relay.ir_pass.infer_type(y1_folded) - y1_expected = relay.ir_pass.infer_type(y1_expected) - assert relay.ir_pass.alpha_equal(y1_folded, y1_expected) + y1_expected = run_opt_pass(y1_expected, transform.InferType()) + assert relay.analysis.alpha_equal(y1_folded, y1_expected) check((2, 4, 10, 10), 8) diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 0ecbfe6b4d4ac..8d358e3f805f4 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -16,6 +16,16 @@ # under the License. 
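# Sketch of how operator fusion is driven through the pass infra in the fuse-ops tests
# below: FuseOps takes its fusion opt level as a constructor argument and is applied to a
# Module rather than to a bare expression. Shapes are examples only.
import tvm
from tvm import relay
from tvm.relay import transform

x = relay.var("x", shape=(10, 20))
y = relay.add(x, relay.const(1.0, "float32"))
func = relay.Function([x], relay.exp(y))

mod = relay.Module.from_expr(func)
mod = transform.FuseOps(fuse_opt_level=2)(mod)
fused = mod[mod.entry_func]   # add and exp are grouped into one fused function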
import tvm from tvm import relay +from tvm.relay import transform + + +def run_opt_pass(expr, opt_pass): + assert isinstance(opt_pass, transform.Pass) + mod = relay.Module.from_expr(expr) + mod = opt_pass(mod) + entry = mod[mod.entry_func] + return entry if isinstance(expr, relay.Function) else entry.body + def test_fuse_simple(): """Simple testcase.""" @@ -37,13 +47,10 @@ def expected(): return relay.Function([x], y) z = before() - z = relay.ir_pass.infer_type(z) - zz = relay.ir_pass.fuse_ops(z, opt_level=2) - zz = relay.ir_pass.infer_type(zz) - zz = relay.ir_pass.fuse_ops(zz) - zz = relay.ir_pass.infer_type(zz) - after = relay.ir_pass.infer_type(expected()) - assert relay.ir_pass.alpha_equal(zz, after) + zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) + zz = run_opt_pass(z, transform.FuseOps()) + after = run_opt_pass(expected(), transform.InferType()) + assert relay.analysis.alpha_equal(zz, after) def test_conv2d_fuse(): @@ -69,7 +76,7 @@ def before(dshape): channels=16) # add can only be fused to z1 z = relay.add(z2, z3) - return relay.Function(relay.ir_pass.free_vars(z), z) + return relay.Function(relay.analysis.free_vars(z), z) def expected(dshape): # segment 0 @@ -111,15 +118,13 @@ def expected(dshape): z2 = relay.Call(f2, [y, relay.var("w3")]) z3 = relay.Call(f3, [y, relay.var("w2"), z2]) z = z3 - return relay.Function(relay.ir_pass.free_vars(z), z) + return relay.Function(relay.analysis.free_vars(z), z) dshape = (1, 16, 64, 64) z = before(dshape) - z = relay.ir_pass.infer_type(z) - zz = relay.ir_pass.fuse_ops(z, opt_level=2) - zz = relay.ir_pass.infer_type(zz) - after = relay.ir_pass.infer_type(expected(dshape)) - assert relay.ir_pass.alpha_equal(zz, after) + zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) + after = run_opt_pass(expected(dshape), transform.InferType()) + assert relay.analysis.alpha_equal(zz, after) def test_concatenate(): @@ -131,7 +136,7 @@ def before(dshape): upsampled = relay.nn.upsampling(pooled, scale=2, layout="NCHW") concat = relay.concatenate((upsampled, x), axis=1) out = relay.add(concat, relay.const(1, "float32")) - return relay.Function(relay.ir_pass.free_vars(out), out) + return relay.Function(relay.analysis.free_vars(out), out) def expected(dshape): x = relay.var("x", shape=dshape) @@ -152,14 +157,12 @@ def expected(dshape): dshape = (1, 16, 64, 64) z = before(dshape) - z = relay.ir_pass.infer_type(z) - zz = relay.ir_pass.fuse_ops(z, opt_level=0) - assert not relay.ir_pass.free_vars(zz) - zz = relay.ir_pass.fuse_ops(z, opt_level=2) - zz = relay.ir_pass.infer_type(zz) - assert not relay.ir_pass.free_vars(zz) - after = relay.ir_pass.infer_type(expected(dshape)) - assert relay.ir_pass.alpha_equal(zz, after) + zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0)) + assert not relay.analysis.free_vars(zz) + zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) + assert not relay.analysis.free_vars(zz) + after = run_opt_pass(expected(dshape), transform.InferType()) + assert relay.analysis.alpha_equal(zz, after) def test_tuple_root(): @@ -170,7 +173,7 @@ def before(dshape): pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) upsampled = relay.nn.upsampling(pooled, scale=2, layout="NCHW") out = relay.Tuple((upsampled, x)) - return relay.Function(relay.ir_pass.free_vars(out), out) + return relay.Function(relay.analysis.free_vars(out), out) def expected(dshape): x = relay.var("x", shape=dshape) @@ -189,15 +192,12 @@ def expected(dshape): dshape = (1, 16, 64, 64) z = before(dshape) - z = 
relay.ir_pass.infer_type(z) - zz = relay.ir_pass.fuse_ops(z, opt_level=0) - assert not relay.ir_pass.free_vars(zz) - zz = relay.ir_pass.fuse_ops(z, opt_level=2) - zz = relay.ir_pass.infer_type(zz) - assert not relay.ir_pass.free_vars(zz) - after = relay.ir_pass.infer_type(expected(dshape)) - assert relay.ir_pass.alpha_equal(zz, after) - + zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0)) + assert not relay.analysis.free_vars(zz) + zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) + assert not relay.analysis.free_vars(zz) + after = run_opt_pass(expected(dshape), transform.InferType()) + assert relay.analysis.alpha_equal(zz, after) def test_stop_fusion(): @@ -224,11 +224,9 @@ def expected(dshape): dshape = (10, 20) z = before(dshape) - z = relay.ir_pass.infer_type(z) - z = relay.ir_pass.fuse_ops(z) - z = relay.ir_pass.infer_type(z) - after = relay.ir_pass.infer_type(expected(dshape)) - assert relay.ir_pass.alpha_equal(z, after) + zz = run_opt_pass(z, transform.FuseOps()) + after = run_opt_pass(expected(dshape), transform.InferType()) + assert relay.analysis.alpha_equal(zz, after) def test_fuse_myia_regression(): @@ -261,10 +259,9 @@ def expected(dshape, dtype): dshape = () dtype = 'int64' f = before(dshape, dtype) - f = relay.ir_pass.infer_type(f) - f = relay.ir_pass.fuse_ops(f) - after = relay.ir_pass.infer_type(expected(dshape, dtype)) - assert relay.ir_pass.alpha_equal(f, after) + zz = run_opt_pass(f, transform.FuseOps()) + after = run_opt_pass(expected(dshape, dtype), transform.InferType()) + assert relay.analysis.alpha_equal(zz, after) def test_fuse_tuple_get_elemwise(): @@ -295,14 +292,12 @@ def expected(dim): dim = 10 z = before(dim) - z = relay.ir_pass.infer_type(z) - zz = relay.ir_pass.fuse_ops(z, opt_level=0) - assert not relay.ir_pass.free_vars(zz) - zz = relay.ir_pass.fuse_ops(z, opt_level=2) - zz = relay.ir_pass.infer_type(zz) - assert not relay.ir_pass.free_vars(zz) - after = relay.ir_pass.infer_type(expected(dim)) - assert relay.ir_pass.alpha_equal(zz, after) + zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0)) + assert not relay.analysis.free_vars(zz) + zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) + assert not relay.analysis.free_vars(zz) + after = run_opt_pass(expected(dim), transform.InferType()) + assert relay.analysis.alpha_equal(zz, after) def test_tuple_get_root(): @@ -332,14 +327,12 @@ def expected(dim): dim = 10 z = before(dim) - z = relay.ir_pass.infer_type(z) - zz = relay.ir_pass.fuse_ops(z, opt_level=0) - assert not relay.ir_pass.free_vars(zz) - zz = relay.ir_pass.fuse_ops(z, opt_level=2) - zz = relay.ir_pass.infer_type(zz) - assert not relay.ir_pass.free_vars(zz) - after = relay.ir_pass.infer_type(expected(dim)) - assert relay.ir_pass.alpha_equal(zz, after) + zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0)) + assert not relay.analysis.free_vars(zz) + zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) + assert not relay.analysis.free_vars(zz) + after = run_opt_pass(expected(dim), transform.InferType()) + assert relay.analysis.alpha_equal(zz, after) fuse0 = relay.transform.FuseOps(fuse_opt_level=0) @@ -356,7 +349,7 @@ def before(x): concat = relay.concatenate((y1, y2, y3), axis=1) out_inj = relay.squeeze(concat) out = relay.add(out_inj, relay.const(1, "float32")) - return relay.Function(relay.ir_pass.free_vars(out), out) + return relay.Function(relay.analysis.free_vars(out), out) def expected(p0): f0 = before(p0) @@ -370,8 +363,8 @@ def expected(p0): fuse0(relay.Module.from_expr(orig)) m = 
fuse2(relay.Module.from_expr(orig)) relay.build(m, 'llvm') - after = relay.ir_pass.infer_type(expected(x)) - assert relay.ir_pass.alpha_equal(m[m.entry_func], after) + after = run_opt_pass(expected(x), transform.InferType()) + assert relay.analysis.alpha_equal(m[m.entry_func], after) def test_tuple_consecutive(): @@ -396,7 +389,7 @@ def before(x): out = relay.add(pooled, relay.const(1, "float32")) out2 = relay.add(out, relay.const(1, "float32")) out_tup = relay.Tuple((out, out2)) - return relay.Function(relay.ir_pass.free_vars(out_tup), out_tup) + return relay.Function(relay.analysis.free_vars(out_tup), out_tup) def expected(dshape): p0 = relay.var("p0", shape=dshape) @@ -425,8 +418,8 @@ def expected(dshape): fuse0(relay.Module.from_expr(orig)) m = fuse2(relay.Module.from_expr(orig)) relay.build(m, 'llvm') - after = relay.ir_pass.infer_type(expected(dshape)) - assert relay.ir_pass.alpha_equal(m[m.entry_func], after) + after = run_opt_pass(expected(dshape), transform.InferType()) + assert relay.analysis.alpha_equal(m[m.entry_func], after) def test_inception_like(): @@ -446,16 +439,16 @@ def before(dshape): x = relay.var("x", shape=dshape) in1 = inception_like(x) in2 = inception_like(in1) - return relay.Function(relay.ir_pass.free_vars(in2), in2) + return relay.Function(relay.analysis.free_vars(in2), in2) def expected(dshape): p0 = relay.var("p0", shape=dshape) c = conv(p0) - f0 = relay.Function(relay.ir_pass.free_vars(c), c) + f0 = relay.Function(relay.analysis.free_vars(c), c) p01 = relay.var("p01", shape=dshape) c = conv(p01) - f1 = relay.Function(relay.ir_pass.free_vars(c), c) + f1 = relay.Function(relay.analysis.free_vars(c), c) p02 = relay.var("p02", shape=dshape) p12 = relay.var("p12", shape=dshape) @@ -466,11 +459,11 @@ def expected(dshape): p03 = relay.var("p03", shape=dshape2) c = conv(p03) - f2 = relay.Function(relay.ir_pass.free_vars(c), c) + f2 = relay.Function(relay.analysis.free_vars(c), c) p04 = relay.var("p04", shape=dshape2) c = conv(p04) - f3 = relay.Function(relay.ir_pass.free_vars(c), c) + f3 = relay.Function(relay.analysis.free_vars(c), c) p05 = relay.var("p05", shape=dshape) p15 = relay.var("p15", shape=dshape) @@ -485,15 +478,15 @@ def expected(dshape): c4 = relay.Call(f3, [concat, relay.var("w4")]) out = relay.Call(f_concat2, [c3, c4]) - return relay.Function(relay.ir_pass.free_vars(out), out) + return relay.Function(relay.analysis.free_vars(out), out) dshape = (1, 16, 64, 64) orig = before(dshape) fuse0(relay.Module.from_expr(orig)) m = fuse2(relay.Module.from_expr(orig)) relay.build(m, 'llvm') - after = relay.ir_pass.infer_type(expected(dshape)) - assert relay.ir_pass.alpha_equal(m[m.entry_func], after) + after = run_opt_pass(expected(dshape), transform.InferType()) + assert relay.analysis.alpha_equal(m[m.entry_func], after) def test_fuse_parallel_injective(): @@ -518,14 +511,12 @@ def expected(): return relay.Function([x], y) z = before() - z = relay.ir_pass.infer_type(z) - zz = relay.ir_pass.fuse_ops(z, opt_level=0) - assert not relay.ir_pass.free_vars(zz) - zz = relay.ir_pass.fuse_ops(z, opt_level=2) - zz = relay.ir_pass.infer_type(zz) - assert not relay.ir_pass.free_vars(zz) - after = relay.ir_pass.infer_type(expected()) - assert relay.ir_pass.alpha_equal(zz, after) + zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0)) + assert not relay.analysis.free_vars(zz) + zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) + assert not relay.analysis.free_vars(zz) + after = run_opt_pass(expected(), transform.InferType()) + assert 
relay.analysis.alpha_equal(zz, after) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py index 6fece1b0a6ddf..400f5d79b1e40 100644 --- a/tests/python/relay/test_pass_gradient.py +++ b/tests/python/relay/test_pass_gradient.py @@ -14,14 +14,23 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import numpy as np + import tvm from tvm import relay -from tvm.relay.ir_pass import free_vars, free_type_vars, gradient -from tvm.relay import create_executor +from tvm.relay.analysis import free_vars, free_type_vars +from tvm.relay import create_executor, transform +from tvm.relay.transform import gradient from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions, make_nat_expr -import numpy as np + +def run_infer_type(expr): + mod = relay.Module.from_expr(expr) + mod = relay.Module.from_expr(expr) + mod = transform.InferType()(mod) + entry = mod[mod.entry_func] + return entry if isinstance(expr, relay.Function) else entry.body def rand(dtype='float32', *shape): @@ -34,7 +43,7 @@ def test_id(): t = relay.TensorType(shape, dtype) x = relay.var("x", t) func = relay.Function([x], x) - back_func = relay.ir_pass.infer_type(gradient(func)) + back_func = run_infer_type(gradient(func)) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) ex = create_executor() x = rand(dtype, *shape) @@ -49,7 +58,7 @@ def test_add(): t = relay.TensorType(shape, dtype) x = relay.var("x", t) func = relay.Function([x], x + x) - back_func = relay.ir_pass.infer_type(gradient(func)) + back_func = run_infer_type(gradient(func)) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) ex = create_executor() x = rand(dtype, *shape) @@ -65,7 +74,7 @@ def test_temp_add(): x = relay.var("x", t) y = x + x func = relay.Function([x], y + y) - back_func = relay.ir_pass.infer_type(gradient(func)) + back_func = run_infer_type(gradient(func)) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) ex = create_executor() x = rand(dtype, *shape) @@ -80,7 +89,7 @@ def test_sub(): t = relay.TensorType(shape, dtype) x = relay.var("x", t) func = relay.Function([x], x - x) - back_func = relay.ir_pass.infer_type(gradient(func)) + back_func = run_infer_type(gradient(func)) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) ex = create_executor() x = rand(dtype, *shape) @@ -103,7 +112,7 @@ def test_broadcast_add(): x = relay.var("x", t1) y = relay.var("y", t2) func = relay.Function([x, y], x + y) - full_func = relay.ir_pass.infer_type(gradient(func)) + full_func = run_infer_type(gradient(func)) assert full_func.checked_type == relay.FuncType([t1, t2], relay.TupleType([relay.TensorType(expected_forward.shape, dtype), relay.TupleType([t1, t2])])) @@ -130,7 +139,7 @@ def test_broadcast_subtract(): x = relay.var("x", t1) y = relay.var("y", t2) func = relay.Function([x, y], x - y) - full_func = relay.ir_pass.infer_type(gradient(func)) + full_func = run_infer_type(gradient(func)) assert full_func.checked_type == relay.FuncType([t1, t2], relay.TupleType([relay.TensorType(expected_forward.shape, dtype), relay.TupleType([t1, t2])])) @@ -155,7 +164,7 @@ def test_tuple(): relay.TupleGetItem(tup, 0) + relay.TupleGetItem(tup, 1) - relay.TupleGetItem(tup, 2))) - back_func = 
relay.ir_pass.infer_type(gradient(func)) + back_func = run_infer_type(gradient(func)) assert back_func.checked_type == relay.FuncType([t, t, t], relay.TupleType([t, relay.TupleType([t, t, t])])) x_nd = rand(dtype, *shape) y_nd = rand(dtype, *shape) @@ -183,7 +192,10 @@ def test_pow(): double = relay.Function([x], x + x) i = relay.var("i", t) func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i)) - back_func = relay.ir_pass.infer_type(gradient(func, mod=mod), mod=mod) + func = gradient(func, mod=mod) + mod[mod.entry_func] = func + m = transform.InferType()(mod) + back_func = m[m.entry_func] assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) i_nd = rand(dtype, *shape) ex = create_executor(mod=mod) @@ -203,7 +215,7 @@ def test_ref(): body = relay.Let(u, relay.RefWrite(r, relay.RefRead(r) + relay.RefRead(r)), body) body = relay.Let(r, relay.RefCreate(x), body) func = relay.Function([x], body) - back_func = relay.ir_pass.infer_type(gradient(func)) + back_func = run_infer_type(gradient(func)) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) x_nd = rand(dtype, *shape) ex = create_executor() @@ -218,11 +230,11 @@ def test_square_second_order(): t = relay.TensorType(shape, dtype) x = relay.var("x", t) func = relay.Function([x], x * x) - back_func = relay.ir_pass.infer_type(gradient(func)) + back_func = run_infer_type(gradient(func)) y = relay.var("y", t) back_func_adjusted = relay.Function([y], relay.TupleGetItem(relay.TupleGetItem(back_func(y), 1), 0)) - back_func_adjusted = relay.ir_pass.infer_type(back_func_adjusted) - back_back_func = relay.ir_pass.infer_type(gradient(back_func_adjusted)) + back_func_adjusted = run_infer_type(back_func_adjusted) + back_back_func = run_infer_type(gradient(back_func_adjusted)) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) x_nd = rand(dtype, *shape) ex = create_executor() @@ -237,8 +249,10 @@ def test_if(): cond = relay.var("cond", shape=(), dtype='uint1') net = relay.If(cond, x, y) net = relay.log(net) - net = relay.ir_pass.infer_type(relay.Function(relay.ir_pass.free_vars(net), net)) - back_func = relay.ir_pass.infer_type(relay.ir_pass.gradient(net, mode='higher_order')) + func = relay.Function(free_vars(net), net) + net = run_infer_type(func) + net = gradient(net, mode='higher_order') + net = run_infer_type(net) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_mac_count.py b/tests/python/relay/test_pass_mac_count.py index a7739a6444733..e68c748d1bb18 100644 --- a/tests/python/relay/test_pass_mac_count.py +++ b/tests/python/relay/test_pass_mac_count.py @@ -18,6 +18,16 @@ import numpy as np import tvm from tvm import relay +from tvm.relay import analysis, transform + + +def run_opt_pass(expr, opt_pass): + assert isinstance(opt_pass, transform.Pass) + mod = relay.Module.from_expr(expr) + mod = opt_pass(mod) + entry = mod[mod.entry_func] + return entry if isinstance(expr, relay.Function) else entry.body + def test_gemm(): n = 512 @@ -30,8 +40,8 @@ def test_gemm(): gemm = relay.nn.dense(data1, data2) func = relay.Function([data1, data2], relay.Tuple(tvm.convert([gemm]))) - func = relay.ir_pass.infer_type(func) - compute_count = relay.ir_pass.get_total_mac_number(func) + func = run_opt_pass(func, transform.InferType()) + compute_count = analysis.get_total_mac_number(func) expect_count = n * m * k assert compute_count == expect_count @@ -56,10 +66,9 @@ def test_conv(): 
channels=output_channel, kernel_size=(kh, kw), padding=(h_padding, w_padding)) - func = relay.Function([data, weight], - relay.Tuple(tvm.convert([conv2d]))) - func = relay.ir_pass.infer_type(func) - compute_count = relay.ir_pass.get_total_mac_number(func) + func = relay.Function([data, weight], relay.Tuple(tvm.convert([conv2d]))) + func = run_opt_pass(func, transform.InferType()) + compute_count = analysis.get_total_mac_number(func) expect_count = batch_size * input_channel * oh * ow * output_channel * kh * kw assert compute_count == expect_count @@ -92,11 +101,9 @@ def test_simple_network(): func = relay.Function([data1, data2, weight_conv, weight_dense], relay.Tuple(tvm.convert([conv2d_1, conv2d_2, dense_1, add, flattened]))) - func = relay.ir_pass.infer_type(func) # alter the CONV 2D data layout to test - func = relay.ir_pass.alter_op_layout(func) - func = relay.ir_pass.infer_type(func) - compute_count = relay.ir_pass.get_total_mac_number(func) + func = run_opt_pass(func, transform.AlterOpLayout()) + compute_count = analysis.get_total_mac_number(func) expect_count = 231411712 assert compute_count == expect_count @@ -123,8 +130,8 @@ def test_depthwise_conv2d(): relay.Tuple(tvm.convert([depthwise_conv2d_1, depthwise_conv2d_2, add]))) - func = relay.ir_pass.infer_type(func) - compute_count = relay.ir_pass.get_total_mac_number(func) + func = run_opt_pass(func, transform.InferType()) + compute_count = analysis.get_total_mac_number(func) assert compute_count == 2 * np.prod(dshape) * 3*3 def test_conv_2d_transpose(): @@ -150,8 +157,8 @@ def test_conv_2d_transpose(): padding=(h_padding, w_padding)) func = relay.Function([data, weight], relay.Tuple(tvm.convert([conv2d_transpose]))) - func = relay.ir_pass.infer_type(func) - compute_count = relay.ir_pass.get_total_mac_number(func) + func = run_opt_pass(func, transform.InferType()) + compute_count = analysis.get_total_mac_number(func) expect_count = batch_size * input_channel * oh * ow * output_channel * kh * kw assert compute_count == expect_count diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index a8f50bdb8f558..930dbe0451983 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -21,11 +21,18 @@ from tvm import relay from tvm.relay import ExprFunctor from tvm.relay import Function, Call -from tvm.relay import ir_pass +from tvm.relay import analysis from tvm.relay import transform as _transform from tvm.relay.testing import ctx_list +def run_infer_type(expr): + mod = relay.Module.from_expr(expr) + mod = _transform.InferType()(mod) + entry = mod[mod.entry_func] + return entry if isinstance(expr, relay.Function) else entry.body + + def get_var_func(): shape = (5, 10) tp = relay.TensorType(shape, "float32") @@ -107,9 +114,9 @@ def get_rand(shape, dtype='float32'): def check_func(func, ref_func): - func = ir_pass.infer_type(func) - ref_func = ir_pass.infer_type(ref_func) - assert ir_pass.graph_equal(func, ref_func) + func = run_infer_type(func) + ref_func = run_infer_type(ref_func) + assert analysis.graph_equal(func, ref_func) def test_module_pass(): @@ -493,8 +500,8 @@ def expected(): mod = seq(mod) zz = mod["main"] - zexpected = ir_pass.infer_type(expected()) - assert relay.ir_pass.alpha_equal(zz, zexpected) + zexpected = run_infer_type(expected()) + assert analysis.alpha_equal(zz, zexpected) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py index 
f2aedd1905d4e..6a7f59c91daa8 100644 --- a/tests/python/relay/test_pass_partial_eval.py +++ b/tests/python/relay/test_pass_partial_eval.py @@ -18,12 +18,13 @@ import numpy as np import tvm from tvm import relay -from tvm.relay.ir_pass import alpha_equal, gradient +from tvm.relay.analysis import alpha_equal from tvm.relay.prelude import Prelude from tvm.relay import op, create_executor, transform from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate from tvm.relay import TensorType, Tuple, If, Module, Clause, PatternConstructor, PatternVar, Match from tvm.relay import GlobalVar, Call +from tvm.relay.transform import gradient from tvm.relay.testing import add_nat_definitions, make_nat_expr def check_eval(expr, expected_result, mod=None, rtol=1e-07): @@ -34,11 +35,19 @@ def check_eval(expr, expected_result, mod=None, rtol=1e-07): np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol) +def run_opt_pass(expr, passes): + passes = passes if isinstance(passes, list) else [passes] + mod = relay.Module.from_expr(expr) + seq = transform.Sequential(passes) + with transform.PassContext(opt_level=3): + mod = seq(mod) + entry = mod[mod.entry_func] + return entry if isinstance(expr, relay.Function) else entry.body + + def tipe(expr): - return transform.OptimizeOnExpr(expr, - [transform.InferType(), - transform.PartialEvaluate(), - transform.InferType()]) + return run_opt_pass(expr, [transform.PartialEvaluate(), + transform.InferType()]) def dcpe(expr, mod=None, grad=False): @@ -52,7 +61,7 @@ def dcpe(expr, mod=None, grad=False): seq = transform.Sequential(passes) mod = seq(mod) return mod[mod.entry_func] - return transform.OptimizeOnExpr(expr, passes) + return run_opt_pass(expr, passes) def test_tuple(): @@ -61,7 +70,7 @@ def test_tuple(): body = TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1) f = Function([x], body, None, [t]) expected = relay.Function([x], x, None, [t]) - expected = transform.OptimizeOnExpr(expected, transform.InferType()) + expected = run_opt_pass(expected, transform.InferType()) assert alpha_equal(dcpe(f), expected) @@ -82,8 +91,7 @@ def test_ref(): body = Let(x, RefWrite(r, RefRead(r) * RefRead(r)), body) body = Let(r, RefCreate(d), body) square = Function([d], body) - expected = transform.OptimizeOnExpr(Function([d], d * d), - transform.InferType()) + expected = run_opt_pass(Function([d], d * d), transform.InferType()) assert alpha_equal(dcpe(square), expected) @@ -95,7 +103,7 @@ def test_empty_ad(): f = Function([d], d) g = dcpe(f, grad=True) expected = Function([d], Tuple([d, Tuple([op.ones_like(d)])])) - expected = transform.OptimizeOnExpr(expected, transform.InferType()) + expected = run_opt_pass(expected, transform.InferType()) assert alpha_equal(g, expected) @@ -114,7 +122,7 @@ def test_ad(): body = Tuple([x, Tuple([grad])]) body = relay.Let(x1, o, body) expected = Function([d], relay.Let(x, m, body)) - expected = transform.OptimizeOnExpr(expected, transform.InferType()) + expected = run_opt_pass(expected, transform.InferType()) assert alpha_equal(g, expected) diff --git a/tests/python/relay/test_pass_quantize.py b/tests/python/relay/test_pass_quantize.py index fe62c3b5cea4f..21aa02df7f3a4 100644 --- a/tests/python/relay/test_pass_quantize.py +++ b/tests/python/relay/test_pass_quantize.py @@ -19,10 +19,18 @@ import tvm from tvm import relay from tvm.relay import quantize as qtz +from tvm.relay import transform + + +def run_infer_type(expr): + mod = relay.Module.from_expr(expr) + mod = 
transform.InferType()(mod) + entry = mod[mod.entry_func] + return entry if isinstance(expr, relay.Function) else entry.body def make_dataset(graph, size=100): - args = relay.ir_pass.infer_type(graph).params + args = run_infer_type(graph).params def create_arr(var): ttype = var.type_annotation np_arr = np.random.uniform(-1.0, 1.0, size=ttype.concrete_shape).astype(ttype.dtype) @@ -40,7 +48,7 @@ def create_arr(var): def test_simulated_quantize(): data = relay.var("data", relay.ty.TensorType((3, 4, 5, 6), "float32")) out = qtz._annotate.attach_simulated_quantize(data, 1) - out = relay.ir_pass.infer_type(out) + out = run_infer_type(out) assert out.checked_type == out.args[0].checked_type assert out.args[1].checked_type == relay.ty.TensorType(tuple(), "float32") assert out.args[2].checked_type == relay.ty.TensorType(tuple(), "float32") @@ -59,7 +67,7 @@ def quantize_weight(arr): def make_graph(data): weight = relay.var("conv_weight") out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=c) - out = relay.Function(relay.ir_pass.free_vars(out), out) + out = relay.Function(relay.analysis.free_vars(out), out) return out def make_qgraph(data, weight): @@ -72,7 +80,7 @@ def make_qgraph(data, weight): padding=(1, 1), channels=c, out_dtype='int32') out = out.astype('float32') out = relay.multiply(out, relay.const(0.00024414062)) - out = relay.Function(relay.ir_pass.free_vars(out), out) + out = relay.Function(relay.analysis.free_vars(out), out) return out np.random.seed(42) @@ -84,11 +92,11 @@ def make_qgraph(data, weight): with qtz.qconfig(skip_conv_layers=None, global_scale=4.0, round_for_shift=False, store_lowbit_output=False): qgraph0 = qtz.quantize(graph, params) - qgraph0 = relay.ir_pass.infer_type(qgraph0) + qgraph0 = run_infer_type(qgraph0) conv_weight = quantize_weight(params['conv_weight']) qgraph1 = make_qgraph(data, conv_weight) - qgraph1 = relay.ir_pass.infer_type(qgraph1) + qgraph1 = run_infer_type(qgraph1) graph = relay.create_executor('graph') res0 = graph.evaluate(qgraph0)(dataset[0]['data']) diff --git a/tests/python/relay/test_pass_simplify_inference.py b/tests/python/relay/test_pass_simplify_inference.py index aad1d9fc6cf5f..4e62fa6dcb08e 100644 --- a/tests/python/relay/test_pass_simplify_inference.py +++ b/tests/python/relay/test_pass_simplify_inference.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
from tvm import relay as rly -from tvm.relay.ir_pass import simplify_inference, alpha_equal +from tvm.relay.transform import SimplifyInference def test_simplify_batchnorm(dtype='float32'): def simple_bn(x, gamma, beta, moving_mean, moving_var, @@ -49,10 +49,13 @@ def check(dim, axis, nstep): y2 = simple_bn(y2 + rly.const(1, dtype), gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis, shape=ttype1.shape) - y1 = rly.ir_pass.infer_type(y1) - y1 = simplify_inference(y1) - assert rly.ir_pass.graph_equal(y1, y2) + mod = rly.Module.from_expr(y1) + simplify = SimplifyInference() + mod = simplify(mod) + y1 = mod["main"].body + + assert rly.analysis.graph_equal(y1, y2) check(2, 1, 1) check(4, 1, 1) diff --git a/tests/python/relay/test_pass_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py index e74168141e63c..c12298e465df3 100644 --- a/tests/python/relay/test_pass_to_a_normal_form.py +++ b/tests/python/relay/test_pass_to_a_normal_form.py @@ -17,13 +17,23 @@ import numpy as np import tvm from tvm import relay -from tvm.relay.ir_pass import alpha_equal, detect_feature +from tvm.relay.analysis import alpha_equal, detect_feature from tvm.relay import op, create_executor, transform from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions, count from tvm.relay.feature import Feature +def run_opt_pass(expr, passes): + passes = passes if isinstance(passes, list) else [passes] + mod = relay.Module.from_expr(expr) + seq = transform.Sequential(passes) + with transform.PassContext(opt_level=3): + mod = seq(mod) + entry = mod[mod.entry_func] + return entry if isinstance(expr, relay.Function) else entry.body + + def check_eval(expr, expected_result, mod=None, rtol=1e-07): ctx = tvm.context("llvm", 0) intrp = create_executor(mod=mod, ctx=ctx, target="llvm") @@ -38,7 +48,7 @@ def test_explicit_bound(): z = op.add(y, y) f = relay.Function([], op.add(z, z)) assert not Feature.fLet in detect_feature(f) - anf = transform.OptimizeOnExpr(f, transform.ToANormalForm()) + anf = run_opt_pass(f, transform.ToANormalForm()) assert Feature.fLet in detect_feature(anf) check_eval(f(), 8.0) check_eval(anf(), 8.0) @@ -52,8 +62,7 @@ def test_order(): x = relay.const(1) val = x + y * z check_eval(val, 7.0) - anf = transform.OptimizeOnExpr(val, [transform.ToANormalForm(), - transform.InferType()]) + anf = run_opt_pass(val, [transform.ToANormalForm(), transform.InferType()]) a = relay.Var('a', relay.IncompleteType()) b = relay.Var('b', relay.IncompleteType()) c = relay.Var('c', relay.IncompleteType()) @@ -65,16 +74,14 @@ def test_order(): expected_output = relay.Let(c, z, expected_output) expected_output = relay.Let(b, y, expected_output) expected_output = relay.Let(a, x, expected_output) - expected_output = transform.OptimizeOnExpr(expected_output, - transform.InferType()) + expected_output = run_opt_pass(expected_output, transform.InferType()) assert alpha_equal(anf, expected_output) def test_if(): cond = relay.const(True) x = relay.If(cond, relay.const(2), relay.const(3)) - anf = transform.OptimizeOnExpr(x, [transform.ToANormalForm(), - transform.InferType()]) + anf = run_opt_pass(x, [transform.ToANormalForm(), transform.InferType()]) a = relay.Var('a', relay.IncompleteType()) b = relay.Var('b', relay.IncompleteType()) c = relay.Var('c', relay.IncompleteType()) @@ -84,8 +91,7 @@ def test_if(): expected_output = relay.If(c, true_branch, false_branch) expected_output = relay.Let(d, expected_output, d) expected_output = relay.Let(c, cond, expected_output) - expected_output = 
transform.OptimizeOnExpr(expected_output, - transform.InferType()) + expected_output = run_opt_pass(expected_output, transform.InferType()) assert alpha_equal(anf, expected_output) @@ -133,7 +139,7 @@ def test_ref(): body = relay.Let(iv, relay.RefRead(i), body) body = relay.Let(i, relay.RefCreate(relay.const(1)), body) check_eval(body, 3) - opt_body = transform.OptimizeOnExpr(body, transform.ToANormalForm()) + opt_body = run_opt_pass(body, transform.ToANormalForm()) check_eval(opt_body, 3) @@ -165,7 +171,7 @@ def test_let(): body = relay.Let(y, x, x + y) body = relay.Let(x, d, body) check_eval(body, 8) - opt_body = transform.OptimizeOnExpr(body, transform.ToANormalForm()) + opt_body = run_opt_pass(body, transform.ToANormalForm()) check_eval(opt_body, 8) @@ -174,7 +180,7 @@ def test_function(): x = relay.Var("x", t) f = relay.Function([x], x + x) d = relay.const(4.0, 'float32') - anf_f = transform.OptimizeOnExpr(f, transform.ToANormalForm()) + anf_f = run_opt_pass(f, transform.ToANormalForm()) assert isinstance(anf_f, relay.Function) check_eval(f(d), 8) check_eval(anf_f(d), 8) diff --git a/tests/python/relay/test_pass_to_graph_normal_form.py b/tests/python/relay/test_pass_to_graph_normal_form.py index 09db48f633d91..9e8c5887ac582 100644 --- a/tests/python/relay/test_pass_to_graph_normal_form.py +++ b/tests/python/relay/test_pass_to_graph_normal_form.py @@ -17,9 +17,15 @@ import numpy as np import tvm from tvm import relay -from tvm.relay import op, create_executor, transform -from tvm.relay.ir_pass import detect_feature -from tvm.relay.feature import Feature +from tvm.relay import op, create_executor, transform, Feature +from tvm.relay.analysis import detect_feature + + +def run_opt_pass(expr, opt_pass): + mod = relay.Module.from_expr(expr) + mod = opt_pass(mod) + entry = mod[mod.entry_func] + return entry if isinstance(expr, relay.Function) else entry.body def check_eval(expr, args, expected_result, mod=None, rtol=1e-07): @@ -40,7 +46,7 @@ def test_implicit_share(): body = relay.Let(z, op.add(y, y), op.add(z, z)) body = relay.Let(y, op.add(x, x), body) f = relay.Function([], relay.Let(x, relay.const(1), body)) - g = transform.OptimizeOnExpr(f, transform.ToGraphNormalForm()) + g = run_opt_pass(f, transform.ToGraphNormalForm()) assert Feature.fLet in detect_feature(f) assert not Feature.fLet in detect_feature(g) check_eval(f, [], 8.0) @@ -54,8 +60,8 @@ def test_round_trip(): body = relay.Let(z, op.add(y, y), op.add(z, z)) body = relay.Let(y, op.add(x, x), body) f = relay.Function([], relay.Let(x, relay.const(1), body)) - g = transform.OptimizeOnExpr(f, transform.ToGraphNormalForm()) - h = transform.OptimizeOnExpr(g, transform.ToANormalForm()) + g = run_opt_pass(f, transform.ToGraphNormalForm()) + h = run_opt_pass(g, transform.ToANormalForm()) assert Feature.fLet in detect_feature(f) assert not Feature.fLet in detect_feature(g) check_eval(f, [], 8.0) diff --git a/tests/python/relay/test_pass_unmatched_cases.py b/tests/python/relay/test_pass_unmatched_cases.py index 4f2bb20ad7d68..776f5a05722d1 100644 --- a/tests/python/relay/test_pass_unmatched_cases.py +++ b/tests/python/relay/test_pass_unmatched_cases.py @@ -18,7 +18,7 @@ import tvm from tvm import relay from tvm.relay.prelude import Prelude -from tvm.relay.ir_pass import unmatched_cases +from tvm.relay.analysis import unmatched_cases def test_empty_match_block(): # empty match block will not match anything, so it should return a wildcard pattern diff --git a/tests/python/relay/test_pass_vars.py b/tests/python/relay/test_pass_vars.py index 
2f1ef36e7878f..70eb047ad03ea 100644 --- a/tests/python/relay/test_pass_vars.py +++ b/tests/python/relay/test_pass_vars.py @@ -16,9 +16,9 @@ # under the License. import tvm from tvm import relay -from tvm.relay.ir_pass import (free_vars, free_type_vars, - bound_vars, bound_type_vars, - all_vars, all_type_vars) +from tvm.relay.analysis import (free_vars, free_type_vars, + bound_vars, bound_type_vars, + all_vars, all_type_vars) def assert_vars_match(actual, expected): assert len(actual) == len(expected) diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 8e047354fafd7..29b79283a1fc7 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -17,16 +17,34 @@ """Test that type checker correcly computes types for expressions. """ -import tvm -import numpy as np -from tvm.relay.ir_pass import infer_type from tvm import relay -from tvm.relay import op -from tvm.relay.scope_builder import ScopeBuilder +from tvm.relay import op, transform, analysis + + +def run_infer_type(expr, mod=None): + if not mod: + mod = relay.Module.from_expr(expr) + mod = transform.InferType()(mod) + entry = mod[mod.entry_func] + return entry if isinstance(expr, relay.Function) else entry.body + else: + if isinstance(expr, relay.GlobalVar): + gv = expr.name_hint + else: + func = expr + if not isinstance(expr, relay.Function): + func = relay.Function(analysis.free_vars(expr), expr) + mod[mod.entry_func] = func + gv = "main" + mod = transform.InferType()(mod) + + if isinstance(expr, (relay.GlobalVar, relay.Function)): + return mod[gv] + return mod[gv].body def assert_has_type(expr, typ, mod=relay.module.Module({})): - checked_expr = infer_type(expr, mod) + checked_expr = run_infer_type(expr, mod) checked_type = checked_expr.checked_type if checked_type != typ: raise RuntimeError("Type mismatch %s vs %s" % ( @@ -48,7 +66,7 @@ def test_monomorphic_let(): sb = relay.ScopeBuilder() x = sb.let('x', relay.const(1.0, "float64")) sb.ret(x) - xchecked = relay.ir_pass.infer_type(sb.get()) + xchecked = run_infer_type(sb.get()) assert xchecked.checked_type == relay.scalar_type("float64" ) @@ -94,7 +112,7 @@ def test_dual_op(): t2 = sb.let("t2", relay.add(t1, x)) sb.ret(t2) f = relay.Function([x], sb.get()) - fchecked = relay.ir_pass.infer_type(f) + fchecked = run_infer_type(f) assert fchecked.checked_type == relay.FuncType([tp], tp) @@ -107,7 +125,7 @@ def @f(%x : Tensor[(10, 10), float32]) { tp = relay.TensorType((10, 10)) x = relay.var("x", tp) f = relay.Function([x], relay.log(x)) - fchecked = relay.ir_pass.infer_type(f) + fchecked = run_infer_type(f) assert fchecked.checked_type == relay.FuncType([tp], tp) @@ -145,7 +163,7 @@ def test_incomplete_call(): f = relay.var('f') func = relay.Function([x, f], relay.Call(f, [x]), tt) - ft = relay.ir_pass.infer_type(func) + ft = run_infer_type(func) f_type = relay.FuncType([tt], tt) assert ft.checked_type == relay.FuncType([tt, f_type], tt) @@ -164,7 +182,7 @@ def test_higher_order_argument(): # function even though id_func takes a type parameter ho_call = ho_func(id_func, relay.const(0, 'int32')) - hc = relay.ir_pass.infer_type(ho_call) + hc = run_infer_type(ho_call) expected = relay.scalar_type('int32') assert hc.checked_type == expected @@ -177,7 +195,7 @@ def test_higher_order_return(): b = relay.TypeVar('b') nested_id = relay.Function([], id_func, relay.FuncType([b], b), [b]) - ft = relay.ir_pass.infer_type(nested_id) + ft = run_infer_type(nested_id) assert ft.checked_type == relay.FuncType([], 
relay.FuncType([b], b), [b]) @@ -198,7 +216,7 @@ def test_higher_order_nested(): [b]) expected = relay.FuncType([choice_t], relay.FuncType([b], b), [b]) - ft = relay.ir_pass.infer_type(top) + ft = run_infer_type(top) assert ft.checked_type == expected @@ -206,8 +224,7 @@ def test_tuple(): tp = relay.TensorType((10,)) x = relay.var("x", tp) res = relay.Tuple([x, x]) - assert (relay.ir_pass.infer_type(res).checked_type == - relay.TupleType([tp, tp])) + assert (run_infer_type(res).checked_type == relay.TupleType([tp, tp])) def test_ref(): @@ -215,17 +232,17 @@ def test_ref(): y = relay.var("y", "float32") r = relay.RefCreate(x) st = relay.scalar_type("float32") - assert relay.ir_pass.infer_type(r).checked_type == relay.RefType(st) + assert run_infer_type(r).checked_type == relay.RefType(st) g = relay.RefRead(r) - assert relay.ir_pass.infer_type(g).checked_type == st + assert run_infer_type(g).checked_type == st w = relay.RefWrite(r, y) - assert relay.ir_pass.infer_type(w).checked_type == relay.TupleType([]) + assert run_infer_type(w).checked_type == relay.TupleType([]) def test_free_expr(): x = relay.var("x", "float32") y = relay.add(x, x) - yy = relay.ir_pass.infer_type(y) + yy = run_infer_type(y) assert yy.checked_type == relay.scalar_type("float32") assert x.vid.same_as(yy.args[0].vid) @@ -234,7 +251,7 @@ def test_type_args(): x = relay.var("x", shape=(10, 10)) y = relay.var("y", shape=(1, 10)) z = relay.add(x, y) - ty_z = relay.ir_pass.infer_type(z) + ty_z = run_infer_type(z) ty_args = ty_z.type_args assert len(ty_args) == 2 assert ty_args[0].dtype == "float32" @@ -256,15 +273,15 @@ def test_global_var_recursion(): func = relay.Function([x], relay.Call(gv, [x]), tt) mod[gv] = func - ft = relay.ir_pass.infer_type(gv, mod) - assert mod[ft].checked_type == relay.FuncType([tt], tt) + ft = run_infer_type(gv, mod) + assert ft.checked_type == relay.FuncType([tt], tt) def test_equal(): i = relay.var('i', shape=[], dtype='int32') eq = op.equal(i, relay.const(0, dtype='int32')) func = relay.Function([i], eq) - ft = relay.ir_pass.infer_type(func) + ft = run_infer_type(func) assert ft.checked_type == relay.FuncType([relay.scalar_type('int32')], relay.scalar_type('bool')) @@ -275,8 +292,7 @@ def test_constructor_type(): a = relay.TypeVar('a') x = relay.Var('x', a) - ct = relay.ir_pass.infer_type( - relay.Function([x], constructor(x), box(a), [a]), mod) + ct = run_infer_type(relay.Function([x], constructor(x), box(a), [a]), mod) expected = relay.FuncType([a], box(a), [a]) assert ct.checked_type == expected @@ -288,8 +304,8 @@ def test_constructor_call(): box_unit = constructor(relay.Tuple([])) box_constant = constructor(relay.const(0, 'float32')) - ut = relay.ir_pass.infer_type(box_unit, mod) - ct = relay.ir_pass.infer_type(box_constant, mod) + ut = run_infer_type(box_unit, mod) + ct = run_infer_type(box_constant, mod) assert ut.checked_type == box(relay.TupleType([])) assert ct.checked_type == box(relay.TensorType((), 'float32')) @@ -308,7 +324,7 @@ def test_adt_match(): relay.Clause(relay.PatternWildcard(), relay.Tuple([]))]) - mt = relay.ir_pass.infer_type(match, mod) + mt = run_infer_type(match, mod) assert mt.checked_type == relay.TupleType([]) @@ -328,7 +344,7 @@ def test_adt_match_type_annotations(): relay.Tuple([]))]) func = relay.Function([x], match) - ft = relay.ir_pass.infer_type(func, mod) + ft = run_infer_type(func, mod) assert ft.checked_type == relay.FuncType([tt], relay.TupleType([])) diff --git a/tests/python/relay/test_type_solver.py b/tests/python/relay/test_type_solver.py index 
81f0222c029a7..655b5d794005a 100644 --- a/tests/python/relay/test_type_solver.py +++ b/tests/python/relay/test_type_solver.py @@ -26,7 +26,7 @@ def make_rel(name, args, num_inputs=None, attrs=None): return relay.ty.TypeRelation(func, args, num_inputs, attrs) def make_solver(): - solver = relay._ir_pass._test_type_solver() + solver = relay._analysis._test_type_solver() solver.Solve = solver("Solve") solver.Unify = solver("Unify") solver.Resolve = solver("Resolve") diff --git a/tests/python/relay/test_typecall.py b/tests/python/relay/test_typecall.py index 4cb8f4f5d2ce5..963f2ac468465 100644 --- a/tests/python/relay/test_typecall.py +++ b/tests/python/relay/test_typecall.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from tvm import relay -from tvm.relay.ir_pass import infer_type +from tvm.relay import transform def test_dup_type(): a = relay.TypeVar("a") @@ -23,7 +23,10 @@ def test_dup_type(): make_id = relay.Function([av], relay.Tuple([av, av]), None, [a]) t = relay.scalar_type("float32") b = relay.Var("b", t) - assert relay.ir_pass.infer_type(make_id(b)).checked_type == relay.TupleType([t, t]) + mod = relay.Module.from_expr(make_id(b)) + mod = transform.InferType()(mod) + inferred = mod[mod.entry_func].body + assert inferred.checked_type == relay.TupleType([t, t]) def test_id_type(): @@ -36,7 +39,9 @@ def test_id_type(): make_id = relay.Var("make_id", relay.FuncType([b], id_type(b), [b])) t = relay.scalar_type("float32") b = relay.Var("b", t) - assert relay.ir_pass.infer_type(make_id(b), mod).checked_type == id_type(t) + mod[mod.entry_func] = relay.Function([], make_id(b)) + mod = transform.InferType()(mod) + assert mod[mod.entry_func].body.checked_type == id_type(t) if __name__ == "__main__": diff --git a/tests/python/unittest/test_graph_tuner_core.py b/tests/python/unittest/test_graph_tuner_core.py index e0d2dc06c192a..1c3171944bc9d 100644 --- a/tests/python/unittest/test_graph_tuner_core.py +++ b/tests/python/unittest/test_graph_tuner_core.py @@ -43,7 +43,7 @@ def _create_data(target, dshape, dtype, layout): w2 = relay.var("w2_weight") conv2 = relay.nn.conv2d(conv1, w2, channels=32, kernel_size=(3, 3), padding=(1, 1)) out = relay.add(conv1, conv2) - net = relay.Function(relay.ir_pass.free_vars(out), out) + net = relay.Function(relay.analysis.free_vars(out), out) net, params = relay.testing.create_workload(net) tasks = autotvm.task.extract_from_program(net, target=target, diff --git a/tests/python/unittest/test_graph_tuner_utils.py b/tests/python/unittest/test_graph_tuner_utils.py index 0847166412d2c..5bbd1c4860c26 100644 --- a/tests/python/unittest/test_graph_tuner_utils.py +++ b/tests/python/unittest/test_graph_tuner_utils.py @@ -51,7 +51,7 @@ def test_has_multiple_inputs(): w0 = relay.var("w0") out2 = relay.nn.conv2d(data, w0) out = relay.add(out1, out2) - net = relay.Function(relay.ir_pass.free_vars(out), out) + net = relay.Function(relay.analysis.free_vars(out), out) net = bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1)}) target_ops = ["conv2d"] node_list = [] @@ -80,7 +80,7 @@ def _count_node(node): op_name_list.append("Tuple") else: op_name_list.append("null") - relay.ir_pass.post_order_visit(net, _count_node) + relay.analysis.post_order_visit(net, _count_node) expr2graph(net, target_ops, node_dict, node_list) for i, item in enumerate(zip(op_name_list, node_list)): @@ -97,7 +97,7 @@ def test_get_direct_ancestor(): out3 = out2 + relay.expr.const(2.5) w1 = relay.var("w1") out = relay.nn.conv2d(out3, w1) - net = 
relay.Function(relay.ir_pass.free_vars(out), out) + net = relay.Function(relay.analysis.free_vars(out), out) net = bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1), "w1": (16, 16, 1, 1)}) target_ops = ["conv2d"] node_list = [] @@ -117,7 +117,7 @@ def test_get_in_nodes(): out3 = out2 + relay.expr.const(2.5) w1 = relay.var("w1") out = relay.nn.conv2d(out3, w1) - net = relay.Function(relay.ir_pass.free_vars(out), out) + net = relay.Function(relay.analysis.free_vars(out), out) net = bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1), "w1": (16, 16, 1, 1)}) target_ops = ["conv2d"] input_names = ["data"] diff --git a/tutorials/frontend/using_external_lib.py b/tutorials/frontend/using_external_lib.py index a33d4eb9dc7a5..35b015bffcd32 100644 --- a/tutorials/frontend/using_external_lib.py +++ b/tutorials/frontend/using_external_lib.py @@ -56,7 +56,7 @@ simple_net = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3,3), channels=out_channels, padding=(1, 1)) simple_net = relay.nn.batch_norm(simple_net, bn_gamma, bn_beta, bn_mmean, bn_mvar)[0] simple_net = relay.nn.relu(simple_net) -simple_net = relay.Function(relay.ir_pass.free_vars(simple_net), simple_net) +simple_net = relay.Function(relay.analysis.free_vars(simple_net), simple_net) data_shape = (batch_size, 3, 224, 224) net, params = testing.create_workload(simple_net) diff --git a/vta/python/vta/top/graphpack.py b/vta/python/vta/top/graphpack.py index 6f901833ea159..f7d7be8c80477 100644 --- a/vta/python/vta/top/graphpack.py +++ b/vta/python/vta/top/graphpack.py @@ -18,9 +18,17 @@ """A Relay implementation of graph packing.""" from tvm import relay -from tvm.relay import op +from tvm.relay import op, transform from tvm.relay import ExprMutator +def run_opt_pass(expr, opt_pass): + """Exectue a relay pass.""" + assert isinstance(opt_pass, transform.Pass) + mod = relay.Module.from_expr(expr) + mod = opt_pass(mod) + entry = mod[mod.entry_func] + return entry if isinstance(expr, relay.Function) else entry.body + def _to_shape(shape): return tuple(int(sh) for sh in shape) @@ -231,7 +239,7 @@ def get_subgraph(expr, start_name, stop_name): """ bitpack_start = op.op.get('annotation.bitpack_start') bitpack_end = op.op.get('annotation.bitpack_end') - anf = relay.ir_pass.to_a_normal_form(expr) + anf = run_opt_pass(expr, transform.ToANormalForm()) def _recursion(anf, start_found, stop_found): """ Helper to obtain the subgraph. 
""" @@ -262,7 +270,7 @@ def _recursion(anf, start_found, stop_found): assert stop_found return anf annotated = _recursion(anf, False, False) - return relay.ir_pass.infer_type(relay.ir_pass.to_graph_normal_form(annotated)) + return run_opt_pass(annotated, transform.ToGraphNormalForm()) def graph_pack(expr, bfactor, @@ -299,10 +307,10 @@ def graph_pack(expr, """ assert isinstance(expr, relay.Function) expr = get_subgraph(expr, start_name, stop_name) - expr = relay.ir_pass.infer_type(expr) + expr = run_opt_pass(expr, transform.InferType()) packer = ExprPack( bfactor, cfactor, weight_bits) expr = packer.visit(expr) assert not packer.start_pack - return relay.ir_pass.infer_type(expr) + return run_opt_pass(expr, transform.InferType()) diff --git a/vta/scripts/tune_resnet.py b/vta/scripts/tune_resnet.py index 21aa96cd350fc..43bc6acc15d59 100644 --- a/vta/scripts/tune_resnet.py +++ b/vta/scripts/tune_resnet.py @@ -139,7 +139,6 @@ def compile_network(opt, env, target): env.WGT_WIDTH, start_name=opt.start_name, stop_name=opt.stop_name) - relay_prog = relay.ir_pass.fold_constant(relay_prog) return relay_prog, params diff --git a/vta/tutorials/autotvm/tune_relay_vta.py b/vta/tutorials/autotvm/tune_relay_vta.py index bdeb6c5d03e2c..9f734bc65d929 100644 --- a/vta/tutorials/autotvm/tune_relay_vta.py +++ b/vta/tutorials/autotvm/tune_relay_vta.py @@ -103,7 +103,6 @@ def compile_network(env, target, model, start_pack, stop_pack): env.WGT_WIDTH, start_name=start_pack, stop_name=stop_pack) - relay_prog = relay.ir_pass.fold_constant(relay_prog) return relay_prog, params diff --git a/vta/tutorials/frontend/deploy_resnet_on_vta.py b/vta/tutorials/frontend/deploy_resnet_on_vta.py index 271630e695588..3e252172444a0 100644 --- a/vta/tutorials/frontend/deploy_resnet_on_vta.py +++ b/vta/tutorials/frontend/deploy_resnet_on_vta.py @@ -172,7 +172,6 @@ env.WGT_WIDTH, start_name=start_pack, stop_name=stop_pack) - relay_prog = relay.ir_pass.fold_constant(relay_prog) # Compile Relay program with AlterOpLayout disabled with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): From 8f13a5ba4118e9453a75ba6b69abdf3f09780229 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 2 Jul 2019 09:17:57 -0700 Subject: [PATCH 09/26] Delete _ir_pass.pyi --- python/tvm/relay/_ir_pass.pyi | 26 -------------------------- 1 file changed, 26 deletions(-) delete mode 100644 python/tvm/relay/_ir_pass.pyi diff --git a/python/tvm/relay/_ir_pass.pyi b/python/tvm/relay/_ir_pass.pyi deleted file mode 100644 index 13035bb36f716..0000000000000 --- a/python/tvm/relay/_ir_pass.pyi +++ /dev/null @@ -1,26 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import tvm -from . import ir -from .env import Module - -def check_expr(env: Module, expr: ir.Expr) -> ir.Type: ... 
-def generalize(env: Module, expr: ir.Expr) -> ir.Expr: ... -def _get_checked_type(expr: ir.Expr) -> ir.Type: ... -def well_formed(expr: ir.Expr) -> bool: ... -def dead_code_elimination(expr: ir.Expr) -> ir.Expr: ... From 988ea2ac78e80826b0f3d48c96f3287421b80762 Mon Sep 17 00:00:00 2001 From: abergeron Date: Tue, 2 Jul 2019 12:23:14 -0400 Subject: [PATCH 10/26] Add dockerfiles for the conda package builds (#3344) * First shot * Add dockerfile for CPU too * Finish the build infrastructure * Remove extra file * Comment out the Jenkinsfile section since it is not ready * Add missing license headers * Update to newer cudnn that anaconda packaged * Bump the build numbers for the newer cudnn * Bring back the toolchain option with a tweak for cuda * Cache some large packages in the docker and update to llvm 7.0.0 * Merge all the python packages together * First fix for the conda cuda builds (again) * Use the tarball version of cudnn since tvm has trouble detecting the other one * Use llvm 8.0 from the numba packages * Also use llvm 8.0 for the cpu builds * Don't use the anaconda compiler for OS X * Enable Metal on OS X builds * Make sure to detect undefined variables in scripts * Fix build when not using cuda --- Jenkinsfile | 18 +++++++ conda/Dockerfile.template | 15 ++++-- conda/Makefile | 22 -------- conda/{nnvm/build.sh => build_cpu.sh} | 19 +++++-- conda/{topi/build.sh => build_cuda.sh} | 18 +++++-- conda/{topi/meta.yaml => cross-linux.cmake} | 45 ++++++----------- conda/nnvm/meta.yaml | 56 --------------------- conda/{build_cuda.py => render_cuda.py} | 23 ++------- conda/tvm-libs/build.sh | 29 ++++++++--- conda/tvm-libs/meta.yaml | 19 +++---- conda/tvm/build.sh | 10 ++++ conda/tvm/meta.yaml | 13 ++++- docker/Dockerfile.conda_cpu | 41 +++++++++++++++ docker/Dockerfile.conda_cuda100 | 46 +++++++++++++++++ docker/Dockerfile.conda_cuda90 | 46 +++++++++++++++++ version.py | 4 +- 16 files changed, 262 insertions(+), 162 deletions(-) delete mode 100644 conda/Makefile rename conda/{nnvm/build.sh => build_cpu.sh} (68%) mode change 100644 => 100755 rename conda/{topi/build.sh => build_cuda.sh} (70%) mode change 100644 => 100755 rename conda/{topi/meta.yaml => cross-linux.cmake} (54%) delete mode 100644 conda/nnvm/meta.yaml rename conda/{build_cuda.py => render_cuda.py} (74%) create mode 100644 docker/Dockerfile.conda_cpu create mode 100644 docker/Dockerfile.conda_cuda100 create mode 100644 docker/Dockerfile.conda_cuda90 diff --git a/Jenkinsfile b/Jenkinsfile index 53645eb14b280..c38ec5296bf35 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -309,6 +309,24 @@ stage('Integration Test') { } } +/* +stage('Build packages') { + parallel 'conda CPU': { + node('CPU') { + sh "${docker_run} tvmai/conda-cpu ./conda/build_cpu.sh + } + }, + 'conda cuda': { + node('CPU') { + sh "${docker_run} tvmai/conda-cuda90 ./conda/build_cuda.sh + sh "${docker_run} tvmai/conda-cuda100 ./conda/build_cuda.sh + } + } + // Here we could upload the packages to anaconda for releases + // and/or the master branch +} +*/ + stage('Deploy') { node('doc') { ws('workspace/tvm/deploy-docs') { diff --git a/conda/Dockerfile.template b/conda/Dockerfile.template index 59b9ac96814ee..1b5dc6fbef5e0 100644 --- a/conda/Dockerfile.template +++ b/conda/Dockerfile.template @@ -15,9 +15,13 @@ # specific language governing permissions and limitations # under the License. 
-FROM nvidia/cuda:{{ cuda_version }}-devel-centos6 +FROM nvidia/cuda:{{ cuda_version }}-devel-ubuntu16.04 -RUN curl -fsSL http://developer.download.nvidia.com/compute/redist/cudnn/v{{ cudnn_short_version }}/cudnn-{{ cuda_version }}-linux-x64-v{{ cudnn_version }}.tgz -O && \ +RUN apt-get update && apt-get install -y --no-install-recommends \ + bzip2 curl sudo binutils && \ + rm -rf /var/lib/apt/lists/* + +RUN curl -fsSL http://developer.download.nvidia.com/compute/redist/cudnn/v{{ cudnn_short_version }}/cudnn-{{ cuda_version }}-linux-x64-v{{ cudnn_version }}.tgz -O && \ tar --no-same-owner -xzf cudnn-{{ cuda_version }}-linux-x64-v{{ cudnn_version }}.tgz -C /usr/local && \ rm cudnn-{{ cuda_version }}-linux-x64-v{{ cudnn_version }}.tgz && \ ldconfig @@ -27,13 +31,16 @@ RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-lat chmod +x ~/miniconda.sh && \ ~/miniconda.sh -b -p /opt/conda && \ rm ~/miniconda.sh && \ + /opt/conda/bin/conda upgrade --all && \ /opt/conda/bin/conda install conda-build conda-verify && \ /opt/conda/bin/conda clean -ya +RUN /opt/conda/bin/conda install --download-only cmake make zlib +RUN /opt/conda/bin/conda install --download-only -c numba llvmdev=8.0.0 + ENV PATH /opt/conda/bin:$PATH ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64 +ENV CONDA_BLD_PATH /tmp WORKDIR /workspace RUN chmod -R a+w /workspace - -CMD conda build --output-folder /workspace/conda/pkg --variants '{cuda: True, cuda_version: {{ cuda_version }}}' /workspace/conda/tvm-libs diff --git a/conda/Makefile b/conda/Makefile deleted file mode 100644 index cda546ac73ce3..0000000000000 --- a/conda/Makefile +++ /dev/null @@ -1,22 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -packages: - conda build tvm-libs - conda build tvm - conda build topi - conda built nnvm diff --git a/conda/nnvm/build.sh b/conda/build_cpu.sh old mode 100644 new mode 100755 similarity index 68% rename from conda/nnvm/build.sh rename to conda/build_cpu.sh index bdd333f57734c..992b1a369b96b --- a/conda/nnvm/build.sh +++ b/conda/build_cpu.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/bin/sh # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -17,6 +17,15 @@ # under the License. 
set -e +set -u -cd nnvm/python -$PYTHON setup.py install --single-version-externally-managed --record=/tmp/record.txt +# This is a fix for a weird bug in conda that makes it think +# it can't write in /tmp +HOME=/tmp +mkdir -p /tmp/.conda/pkgs +touch /tmp/.conda/pkgs/urls.txt +touch /tmp/.conda/environments.txt + + +conda build --output-folder=conda/pkg -c numba conda/tvm-libs +conda build --output-folder=conda/pkg -m conda/conda_build_config.yaml conda/tvm diff --git a/conda/topi/build.sh b/conda/build_cuda.sh old mode 100644 new mode 100755 similarity index 70% rename from conda/topi/build.sh rename to conda/build_cuda.sh index 4e5aafb937660..2c9a20ae66aec --- a/conda/topi/build.sh +++ b/conda/build_cuda.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/bin/sh # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -17,6 +17,14 @@ # under the License. set -e +set -u -cd topi/python -$PYTHON setup.py install --single-version-externally-managed --record=/tmp/record.txt +# This is a fix for a weird bug in conda that makes it think +# it can't write in /tmp +HOME=/tmp +mkdir -p /tmp/.conda/pkgs +touch /tmp/.conda/pkgs/urls.txt +touch /tmp/.conda/environments.txt + + +conda build --output-folder=conda/pkg --variants "{cuda: True, cuda_version: ${CUDA_VERSION%.*}}" -c numba conda/tvm-libs diff --git a/conda/topi/meta.yaml b/conda/cross-linux.cmake similarity index 54% rename from conda/topi/meta.yaml rename to conda/cross-linux.cmake index f4bc8950d4c49..360400267ae07 100644 --- a/conda/topi/meta.yaml +++ b/conda/cross-linux.cmake @@ -15,37 +15,24 @@ # specific language governing permissions and limitations # under the License. -{% set version = "0.6.dev" %} +# this one is important +set(CMAKE_SYSTEM_NAME Linux) +set(CMAKE_PLATFORM Linux) +#this one not so much +set(CMAKE_SYSTEM_VERSION 1) -package: - name: topi - version: {{ version }} +# specify the cross compiler +set(CMAKE_C_COMPILER $ENV{CC}) -source: - path: ../.. 
+# where is the target environment +set(CMAKE_FIND_ROOT_PATH $ENV{PREFIX} $ENV{BUILD_PREFIX}/$ENV{HOST}/sysroot) -build: - number: 1 +# search for programs in the build host directories +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) -requirements: - host: - - python {{ python }} - - numpy - - setuptools - - decorator - - tvm-libs =={{ version }} - run: - - python - - {{ pin_compatible('numpy') }} - - decorator - - tvm-libs =={{ version }} - - tvm =={{ version }} +# for libraries and headers in the target directories +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) -test: - imports: - - topi - -about: - home: https://github.com/dmlc/tvm - license: Apache2 - summary: "TOPI: TVM Operator Inventory" +# god-awful hack because it seems to not run correct tests to determine this: +set(__CHAR_UNSIGNED___EXITCODE 1) diff --git a/conda/nnvm/meta.yaml b/conda/nnvm/meta.yaml deleted file mode 100644 index d948484a61e5f..0000000000000 --- a/conda/nnvm/meta.yaml +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -{% set version = "0.6.dev" %} - -package: - name: nnvm - version: {{ version }} - -source: - path: ../.. - -build: - number: 1 - skip: True # [win] - -requirements: - build: - - {{ compiler('cxx') }} - host: - - python {{ python }} - - cython - - numpy - - setuptools - - decorator - - tvm-libs =={{ version }} - run: - - tvm =={{ version }} - - topi =={{ version }} - - tvm-libs =={{ version }} - - python - - {{ pin_compatible('numpy') }} - - decorator - -test: - imports: - - nnvm - -about: - home: https://github.com/dmlc/nnvm - license: Apache2 - summary: Bring deep learning to bare metal diff --git a/conda/build_cuda.py b/conda/render_cuda.py similarity index 74% rename from conda/build_cuda.py rename to conda/render_cuda.py index 47af6ce4564e9..8057892fd83c1 100644 --- a/conda/build_cuda.py +++ b/conda/render_cuda.py @@ -29,8 +29,8 @@ # and from conda. 
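+# Rough usage sketch (run from the tvm source root; not wired into CI):
+#   python conda/render_cuda.py          # render Dockerfiles for every entry in CUDA_VERSIONS
+#   python conda/render_cuda.py 10.0     # render docker/Dockerfile.conda_cuda100 only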
# These two must be in sync -CUDNN_FULL_VERSION = '7.3.1.20' -CUDNN_VERSION = '7.3.1' +CUDNN_FULL_VERSION = '7.6.0.64' +CUDNN_VERSION = '7.6.0' condadir = os.path.dirname(sys.argv[0]) @@ -47,30 +47,15 @@ def render_dockerfile(version): cudnn_short_version=CUDNN_VERSION, cudnn_version=CUDNN_FULL_VERSION) fname = os.path.join(condadir, - 'Dockerfile.cuda' + version.replace('.', '')) + '../docker/Dockerfile.conda_cuda' + version.replace('.', '')) with open(fname, 'w') as f: f.write(txt) return fname -def build_docker(version): - vv = version.replace('.', '') - fname = render_dockerfile(version) - tagname = f'tvm-cuda{ vv }-forge' - subprocess.run(['docker', 'build', '-t', tagname, - condadir, '-f', fname], check=True) - return tagname - - -def build_pkg(version): - tagname = build_docker(version) - subprocess.run(['docker', 'run', '--rm', '-v', f'{ srcdir }:/workspace', - tagname], check=True) - - if __name__ == '__main__': build_versions = CUDA_VERSIONS if len(sys.argv) > 1: build_versions = sys.argv[1:] for version in build_versions: - build_pkg(version) + render_dockerfile(version) diff --git a/conda/tvm-libs/build.sh b/conda/tvm-libs/build.sh index e0b85910475ea..94919c60e7797 100644 --- a/conda/tvm-libs/build.sh +++ b/conda/tvm-libs/build.sh @@ -17,24 +17,37 @@ # under the License. set -e - -if [ "$cuda" == "True" ]; then - CUDA_OPT="-DUSE_CUDA=ON -DUSE_CUBLAS=ON -DUSE_CUDNN=ON" -else - CUDA_OPT="" -fi +set -u if [ "$target_platform" == "osx-64" ]; then # macOS 64 bits - METAL_OPT="" # Conda can only target 10.9 for now + METAL_OPT="-DUSE_METAL=ON" + TOOLCHAIN_OPT="-DCMAKE_OSX_DEPLOYMENT_TARGET=10.11" else METAL_OPT="" + if [ "$target_platform" == "linux-64" ]; then + # Linux 64 bits + TOOLCHAIN_OPT="-DCMAKE_TOOLCHAIN_FILE=${RECIPE_DIR}/../cross-linux.cmake" + else + # Windows (or 32 bits, which we don't support) + TOOLCHAIN_OPT="" + fi +fi + +# When cuda is not set, we default to False +cuda=${cuda:-False} + +if [ "$cuda" == "True" ]; then + CUDA_OPT="-DUSE_CUDA=ON -DUSE_CUBLAS=ON -DUSE_CUDNN=ON" + TOOLCHAIN_OPT="" +else + CUDA_OPT="" fi rm -rf build || true mkdir -p build cd build -cmake $METAL_OPT $CUDA_OPT -DUSE_LLVM=$PREFIX/bin/llvm-config -DINSTALL_DEV=ON -DCMAKE_INSTALL_PREFIX="$PREFIX" .. +cmake $METAL_OPT $CUDA_OPT -DUSE_LLVM=$PREFIX/bin/llvm-config -DINSTALL_DEV=ON -DCMAKE_INSTALL_PREFIX="$PREFIX" $TOOLCHAIN_OPT .. make -j${CPU_COUNT} VERBOSE=1 make install cd .. diff --git a/conda/tvm-libs/meta.yaml b/conda/tvm-libs/meta.yaml index aad8f251c2a69..e3422a2174efe 100644 --- a/conda/tvm-libs/meta.yaml +++ b/conda/tvm-libs/meta.yaml @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -{% set version = "0.6.dev" %} +{% set version = "0.6.dev1" %} package: name: tvm-libs @@ -25,21 +25,22 @@ source: path: ../.. 
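+# The `cuda` and `cuda_version` variables used below are supplied as conda-build variants,
+# e.g. `--variants "{cuda: True, cuda_version: 10.0}"` as done in conda/build_cuda.sh.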
build: - number: 1 - string: cuda{{ cuda_version }}_{{ PKG_BUILDNUM }} # [cuda] + number: 0 + string: cuda{{ cuda_version | replace('.', '') }}h{{ PKG_HASH }}_{{ PKG_BUILDNUM }} # [cuda] requirements: build: - # The OS X build will require some manual setup or it will break - # See https://docs.conda.io/projects/conda-build/en/latest/source/resources/compiler-tools.html#macos-sdk - - {{ compiler('cxx') }} - host: + # The anaconda compilers for OS X are old an annoying + # so we rely on the platform ones for now + - {{ compiler('cxx') }} # [linux] - cmake - - llvmdev ==6.0.0 + - make + host: + - llvmdev ==8.0.0 - zlib # [linux] run: - {{ pin_compatible('cudatoolkit', lower_bound=cuda_version, max_pin='x.x') }} # [cuda] - - {{ pin_compatible('cudnn', lower_bound='7.3.1', max_pin='x') }} # [cuda] + - {{ pin_compatible('cudnn', lower_bound='7.6.0', max_pin='x') }} # [cuda] about: home: https://github.com/dmlc/tvm diff --git a/conda/tvm/build.sh b/conda/tvm/build.sh index 6626aa5920914..494f90f0afa01 100644 --- a/conda/tvm/build.sh +++ b/conda/tvm/build.sh @@ -17,6 +17,16 @@ # under the License. set -e +set -u cd python $PYTHON setup.py install --single-version-externally-managed --record=/tmp/record.txt +cd .. + +cd topi/python +$PYTHON setup.py install --single-version-externally-managed --record=/tmp/record.txt +cd ../.. + +cd nnvm/python +$PYTHON setup.py install --single-version-externally-managed --record=/tmp/record.txt +cd ../.. diff --git a/conda/tvm/meta.yaml b/conda/tvm/meta.yaml index 221dc7950f753..0daca4bcea2bd 100644 --- a/conda/tvm/meta.yaml +++ b/conda/tvm/meta.yaml @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -{% set version = "0.6.dev" %} +{% set version = "0.6.dev1" %} package: name: tvm @@ -25,7 +25,7 @@ source: path: ../.. build: - number: 1 + number: 0 requirements: build: @@ -46,6 +46,15 @@ requirements: test: imports: - tvm + - topi + - nnvm + requires: + - nose + - scipy + source_files: + - tests/python + commands: + - python -m nose -v tests/python/integration about: home: https://github.com/dmlc/tvm diff --git a/docker/Dockerfile.conda_cpu b/docker/Dockerfile.conda_cpu new file mode 100644 index 0000000000000..0660b5daa0e26 --- /dev/null +++ b/docker/Dockerfile.conda_cpu @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
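+# Image for the (currently commented-out) 'Build packages' Jenkins stage. A local run
+# might look like the following, with the tvm checkout mounted at /workspace (the exact
+# docker flags are an assumption, not taken from CI):
+#   docker build -t tvmai/conda-cpu -f docker/Dockerfile.conda_cpu docker
+#   docker run --rm -v "$PWD":/workspace tvmai/conda-cpu ./conda/build_cpu.sh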
+ +FROM ubuntu:16.04 + +RUN apt-get update && apt-get install -y bzip2 curl sudo binutils && rm -rf /var/lib/apt/lists/* + +RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ + chmod +x ~/miniconda.sh && \ + ~/miniconda.sh -b -p /opt/conda && \ + rm ~/miniconda.sh && \ + /opt/conda/bin/conda upgrade --all && \ + /opt/conda/bin/conda install conda-build conda-verify && \ + /opt/conda/bin/conda clean -ya + +# Cache some of the packages for the builds +RUN /opt/conda/bin/conda install --download-only cmake make zlib && \ + /opt/conda/bin/conda install --download-only -c numba llvmdev=8.0.0 && \ + /opt/conda/bin/conda create -n py35 --download-only nose scipy numpy=1.11 cython decorator python=3.5 && \ + /opt/conda/bin/conda create -n py36 --download-only nose scipy numpy=1.11 cython decorator python=3.6 && \ + /opt/conda/bin/conda create -n py37 --download-only nose scipy numpy=1.11 cython decorator python=3.7 + +ENV PATH /opt/conda/bin:$PATH +ENV CONDA_BLD_PATH /tmp + +WORKDIR /workspace +RUN chmod -R a+w /workspace diff --git a/docker/Dockerfile.conda_cuda100 b/docker/Dockerfile.conda_cuda100 new file mode 100644 index 0000000000000..d6e1cddbfd373 --- /dev/null +++ b/docker/Dockerfile.conda_cuda100 @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +FROM nvidia/cuda:10.0-devel-ubuntu16.04 + +RUN apt-get update && apt-get install -y --no-install-recommends \ + bzip2 curl sudo binutils && \ + rm -rf /var/lib/apt/lists/* + +RUN curl -fsSL http://developer.download.nvidia.com/compute/redist/cudnn/v7.6.0/cudnn-10.0-linux-x64-v7.6.0.64.tgz -O && \ + tar --no-same-owner -xzf cudnn-10.0-linux-x64-v7.6.0.64.tgz -C /usr/local && \ + rm cudnn-10.0-linux-x64-v7.6.0.64.tgz && \ + ldconfig + + +RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ + chmod +x ~/miniconda.sh && \ + ~/miniconda.sh -b -p /opt/conda && \ + rm ~/miniconda.sh && \ + /opt/conda/bin/conda upgrade --all && \ + /opt/conda/bin/conda install conda-build conda-verify && \ + /opt/conda/bin/conda clean -ya + +RUN /opt/conda/bin/conda install --download-only cmake make zlib +RUN /opt/conda/bin/conda install --download-only -c numba llvmdev=8.0.0 + +ENV PATH /opt/conda/bin:$PATH +ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64 +ENV CONDA_BLD_PATH /tmp + +WORKDIR /workspace +RUN chmod -R a+w /workspace \ No newline at end of file diff --git a/docker/Dockerfile.conda_cuda90 b/docker/Dockerfile.conda_cuda90 new file mode 100644 index 0000000000000..f55aa1bf2e126 --- /dev/null +++ b/docker/Dockerfile.conda_cuda90 @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +FROM nvidia/cuda:9.0-devel-ubuntu16.04 + +RUN apt-get update && apt-get install -y --no-install-recommends \ + bzip2 curl sudo binutils && \ + rm -rf /var/lib/apt/lists/* + +RUN curl -fsSL http://developer.download.nvidia.com/compute/redist/cudnn/v7.6.0/cudnn-9.0-linux-x64-v7.6.0.64.tgz -O && \ + tar --no-same-owner -xzf cudnn-9.0-linux-x64-v7.6.0.64.tgz -C /usr/local && \ + rm cudnn-9.0-linux-x64-v7.6.0.64.tgz && \ + ldconfig + + +RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ + chmod +x ~/miniconda.sh && \ + ~/miniconda.sh -b -p /opt/conda && \ + rm ~/miniconda.sh && \ + /opt/conda/bin/conda upgrade --all && \ + /opt/conda/bin/conda install conda-build conda-verify && \ + /opt/conda/bin/conda clean -ya + +RUN /opt/conda/bin/conda install --download-only cmake make zlib +RUN /opt/conda/bin/conda install --download-only -c numba llvmdev=8.0.0 + +ENV PATH /opt/conda/bin:$PATH +ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64 +ENV CONDA_BLD_PATH /tmp + +WORKDIR /workspace +RUN chmod -R a+w /workspace \ No newline at end of file diff --git a/version.py b/version.py index 1df897a40fb4f..c949d5c2ead2f 100644 --- a/version.py +++ b/version.py @@ -24,8 +24,6 @@ - tvm-root/include/tvm/runtime/c_runtime_api.h - tvm-root/web/tvm_runtime.js - tvm-root/conda/tvm/meta.yaml -- tvm-root/conda/topi/meta.yaml -- tvm-root/conda/nnvm/meta.yaml - tvm-root/conda/tvm-libs/meta.yaml """ import os @@ -71,7 +69,7 @@ def main(): update(os.path.join(proj_root, "include", "tvm", "runtime", "c_runtime_api.h"), "(?<=TVM_VERSION \")[.0-9a-z]+", __version__) # conda - for path in ["tvm", "topi", "nnvm", "tvm-libs"]: + for path in ["tvm", "tvm-libs"]: update(os.path.join(proj_root, "conda", path, "meta.yaml"), "(?<=version = \")[.0-9a-z]+", __version__) # web From d1eb1229f6da900b8bd0a8bdafc839106cf243f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Tue, 2 Jul 2019 19:45:23 -0700 Subject: [PATCH 11/26] [Relay] Continuation Passing Style (#3456) * save add me find type checker problem save save lint do lint reset ti add some doc add failed test case add recursion for cps add recursion for cps fix pytest lint save fix test error lint save fix error * fix rebase * fix * fix test * lint * lint * restore rewriteannotationops * do --- include/tvm/relay/analysis.h | 11 - include/tvm/relay/transform.h | 67 +++ python/tvm/relay/testing/__init__.py | 15 + python/tvm/relay/transform.py | 64 ++- src/relay/ir/adt.cc | 6 +- src/relay/ir/module.cc | 3 +- src/relay/ir/pretty_printer.cc | 12 +- src/relay/ir/type_functor.cc | 2 +- src/relay/pass/de_duplicate.cc | 122 ++++++ src/relay/pass/dependency_graph.h | 2 +- src/relay/pass/let_list.h | 25 +- src/relay/pass/partial_eval.cc | 84 +--- src/relay/pass/to_a_normal_form.cc | 40 +- 
src/relay/pass/to_cps.cc | 397 ++++++++++++++++++ src/relay/pass/type_infer.cc | 14 +- tests/python/relay/test_pass_fuse_ops.py | 9 +- tests/python/relay/test_pass_gradient.py | 10 +- .../relay/test_pass_to_a_normal_form.py | 14 + tests/python/relay/test_pass_to_cps.py | 100 +++++ 19 files changed, 840 insertions(+), 157 deletions(-) create mode 100644 src/relay/pass/de_duplicate.cc create mode 100644 src/relay/pass/to_cps.cc create mode 100644 tests/python/relay/test_pass_to_cps.py diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h index e3d16b6eda739..deb9c7dec0c56 100644 --- a/include/tvm/relay/analysis.h +++ b/include/tvm/relay/analysis.h @@ -251,17 +251,6 @@ TVM_DLL tvm::Array AllTypeVars(const Expr& expr, const Module& mod); */ TVM_DLL tvm::Array AllTypeVars(const Type& t, const Module& mod); -/*! - * \brief Rewrite the annotated program. - * - * \param expr The expression. - * \param fallback_device The fallback device which is the default device for - * operators without annotation. - * - * \return The updated program. - */ -TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device); - /*! * \brief Collect the device mapping information of each expression. * diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index bb8638abbabf1..93129cf57a279 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -404,6 +404,22 @@ TVM_DLL Pass RewriteAnnotatedOps(int fallback_device); */ TVM_DLL Pass ToANormalForm(); +/*! + * \brief Turn an expression into continuation passing style(CPS). + * + * CPS mean that every function will, instead of returning the result directly, + * be passed down an extra function (called the continuation) as argument, + * and pass the result to the continuation instead. + * + * Thus, every function call has to be passed an extra argument + * that represent the rest of the computation (Hence the name of continuation). + * + * Similarly, all other compute will be wrapped and call the continuation as well. + * + * \return the pass. + */ +TVM_DLL Pass ToCPS(); + /*! * \brief Remove let binding and directly share via pointer instead. * @@ -586,6 +602,57 @@ TVM_DLL Expr ForwardRewrite(const Expr& expr, std::function fcontext = nullptr, std::function fmulti_ref_trigger = nullptr); +/*! + * \brief Rewrite the annotated program. + * + * \param expr The expression. + * \param fallback_device The fallback device which is the default device for + * operators without annotation. + * + * \return The updated program. + */ +TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device); + +/*! + * \brief Turn an expression into continuation passing style(CPS). + * + * CPS mean that every function will, instead of returning the result directly, + * be passed down an extra function (called the continuation) as argument, + * and pass the result to the continuation instead. + * + * Thus, every function call has to be passed an extra argument + * that represent the rest of the computation (Hence the name of continuation). + * + * Similarly, all other compute will be wrapped and call the continuation as well. + * + * \param f the function. + * \param mod the module. + * + * \return the converted Function. + */ +TVM_DLL Function ToCPS(const Function& f, const Module& mod); + +/*! + * \brief Remove the continuation argument of a CPS function. + * + * Note that this only transform the type back into un-CPS form + * when there is no higher order input/output. + * + * \param f the function. 
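+ *
+ * \note f is expected to be in the form produced by ToCPS: its last parameter is the
+ * continuation and its last type parameter is the answer type; both are dropped from
+ * the resulting function's signature.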
+ * + * \return the converted Function. + */ +TVM_DLL Function UnCPS(const Function& f); + +/*! + * \brief Deduplicate the bound variables and type variables in the expression. + * + * \param e the expression. + * + * \return the deduplicated expression. + */ +TVM_DLL Expr DeDup(const Expr& e); + } // namespace relay } // namespace tvm diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index 7a5007bbfb8f2..9d12529e576f3 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -17,6 +17,9 @@ """Utilities for testing and benchmarks""" from __future__ import absolute_import as _abs +import tvm.relay as relay +from tvm.relay import transform + from . import mlp from . import resnet from . import dqn @@ -32,3 +35,15 @@ from .config import ctx_list from .init import create_workload from .nat import add_nat_definitions, count, make_nat_value, make_nat_expr + + +def run_opt_pass(expr, opt_pass): + assert isinstance(opt_pass, transform.Pass) + mod = relay.Module.from_expr(expr) + mod = opt_pass(mod) + entry = mod[mod.entry_func] + return entry if isinstance(expr, relay.Function) else entry.body + + +def run_infer_type(expr): + return run_opt_pass(expr, transform.InferType()) diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 255718c627f0e..f77a532ba7387 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -446,6 +446,20 @@ def ToANormalForm(): return _transform.ToANormalForm() +def ToCPS(expr, mod=None): + """ + Turn expression into continuation passing style(CPS). + + Every intermediate compute will be passed to a continuation. + + Returns + ------- + result: tvm.relay.Pass + The registered pass that transforms an expression into CPS. + """ + return _ir_pass.to_cps(expr, mod) + + def EtaExpand(): """Add abstraction over a function @@ -495,14 +509,6 @@ def PartialEvaluate(): expression is provided. Otherwise, it will rely on the pass manager to carry out transformation. - Parameters - ---------- - expr : Optional[tvm.relay.Expr] - The input expression. - - mod : Optional[tvm.relay.Module] - The global module. - Returns ------- ret: tvm.relay.Pass @@ -554,6 +560,48 @@ def gradient(expr, mod=None, mode='higher_order'): raise Exception('unknown mode') +def to_cps(func, mod=None): + """ + Turn expression into CPS expression. + + Every intermediate compute will be passed to a continuation. + + Parameters + ---------- + func: tvm.relay.Function + The input function. + + mod: Optional[tvm.relay.Module] + The global module. + + Returns + ------- + result: tvm.relay.Function + The output function. + """ + return _transform.to_cps(func, mod) + + +def un_cps(func): + """ + Turn an cps function into a Function without the continuation argument. + + Note that this will not give the exact same interface as before cps: + If the input/output is higher order, they will still be in cps form. 
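+    Composing to_cps and un_cps is expected to preserve semantics; see test_recursion in
+    tests/python/relay/test_pass_to_cps.py for a round-trip example.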
+ + Parameters + ---------- + func: tvm.relay.Function + The input function + + Returns + ------- + result: tvm.relay.Function + The output function + """ + return _transform.un_cps(func) + + def _wrap_class_module_pass(pass_cls, pass_info): """Wrap a python class as function pass""" class PyModulePass(ModulePass): diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index b59281a4f1fd9..3eb1d99f5a889 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,7 @@ */ /*! - * Copyright (c) 2018 by Contributors + * Copyright (c) 2019 by Contributors * \file src/tvm/ir/adt.cc * \brief AST nodes for Relay algebraic data types (ADTs). */ diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 51a2aeeeb111f..4286be293b428 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -89,8 +89,9 @@ GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const { } void ModuleNode::Add(const GlobalVar& var, - const Function& func, + const Function& f, bool update) { + Function func = Downcast(DeDup(f)); // Type check the item before we add it to the module. auto mod = GetRef(this); Function checked_func = InferType(func, mod, var); diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 7a61079204edc..39fc36fba4baf 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -645,11 +645,21 @@ class PrettyPrinter : Doc VisitType_(const FuncTypeNode* node) final { Doc doc; + doc << "fn "; + if (node->type_params.size() != 0) { + doc << "<"; + std::vector type_params; + for (Type type_param : node->type_params) { + type_params.push_back(Print(type_param)); + } + doc << PrintVec(type_params); + doc << ">"; + } std::vector arg_types; for (Type arg_type : node->arg_types) { arg_types.push_back(Print(arg_type)); } - return doc << "fn (" << PrintVec(arg_types) << ") -> " << Print(node->ret_type); + return doc << "(" << PrintVec(arg_types) << ") -> " << Print(node->ret_type); } Doc VisitType_(const RefTypeNode* node) final { diff --git a/src/relay/ir/type_functor.cc b/src/relay/ir/type_functor.cc index 9fca2e0326859..516f4c875b20c 100644 --- a/src/relay/ir/type_functor.cc +++ b/src/relay/ir/type_functor.cc @@ -221,7 +221,7 @@ class TypeBinder : public TypeMutator { }; Type Bind(const Type& type, const tvm::Map& args_map) { - return TypeBinder(args_map).VisitType(type); + return type.defined() ? TypeBinder(args_map).VisitType(type) : type; } } // namespace relay diff --git a/src/relay/pass/de_duplicate.cc b/src/relay/pass/de_duplicate.cc new file mode 100644 index 0000000000000..d5d4f69606539 --- /dev/null +++ b/src/relay/pass/de_duplicate.cc @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * + * \file de_duplicate.cc + * \brief Use a fresh Id for every Var to make the result well-formed. + */ + +#include +#include +#include +#include "../ir/type_functor.h" + +namespace tvm { +namespace relay { + +Expr DeDup(const Expr& e) { + class DeDupMutator : public TypeMutator, + public ExprMutator, + public PatternMutator { + public: + TypeVar Fresh(const TypeVar& tv) { + TypeVar ret = TypeVarNode::make(tv->var->name_hint, tv->kind); + type_rename_[tv] = ret; + return ret; + } + + Var Fresh(const Var& v) { + Var ret = VarNode::make(v->name_hint(), VisitType(v->type_annotation)); + rename_[v] = ret; + return ret; + } + + Expr VisitExpr(const Expr& e) final { + return ExprMutator::VisitExpr(e); + } + + Expr VisitExpr_(const VarNode* op) final { + Var v = GetRef(op); + return rename_.count(v) != 0 ? rename_.at(v) : v; + } + + Expr VisitExpr_(const LetNode* op) final { + Var v = Fresh(op->var); + return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body)); + } + + Type VisitType(const Type& t) final { + return t.defined() ? TypeMutator::VisitType(t) : t; + } + + Expr VisitExpr_(const FunctionNode* op) final { + tvm::Array type_params; + for (const TypeVar& type_param : op->type_params) { + type_params.push_back(Fresh(type_param)); + } + tvm::Array params; + for (const Var& param : op->params) { + params.push_back(Fresh(param)); + } + return FunctionNode::make(params, + VisitExpr(op->body), + VisitType(op->ret_type), + type_params, + op->attrs); + } + + Pattern VisitPattern(const Pattern& p) final { + return PatternMutator::VisitPattern(p); + } + + Pattern VisitPattern_(const PatternVarNode* op) final { + return PatternVarNode::make(Fresh(op->var)); + } + + Clause VisitClause(const Clause& c) final { + Pattern pat = VisitPattern(c->lhs); + return ClauseNode::make(pat, VisitExpr(c->rhs)); + } + + Type VisitType_(const TypeVarNode* op) final { + TypeVar v = GetRef(op); + return type_rename_.count(v) != 0 ? type_rename_.at(v) : v; + } + + Var VisitVar(const Var& v) final { + return Fresh(v); + } + + private: + std::unordered_map rename_; + std::unordered_map type_rename_; + }; + + Expr ret = DeDupMutator().VisitExpr(e); + CHECK_EQ(FreeVars(ret).size(), FreeVars(e).size()); + return ret; +} + +TVM_REGISTER_API("relay._transform.dedup") +.set_body_typed(DeDup); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/dependency_graph.h b/src/relay/pass/dependency_graph.h index 7f53918ebcb7f..5e2b08c352f09 100644 --- a/src/relay/pass/dependency_graph.h +++ b/src/relay/pass/dependency_graph.h @@ -20,7 +20,7 @@ /*! * Copyright (c) 2019 by Contributors. * \file tvm/relay/pass/dependency_graph.h - * \brief + * \brief create a dependency graph. 
*/ #ifndef TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_ #define TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_ diff --git a/src/relay/pass/let_list.h b/src/relay/pass/let_list.h index 9f56b22fc13e9..1b422d2a878f0 100644 --- a/src/relay/pass/let_list.h +++ b/src/relay/pass/let_list.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,7 @@ */ /*! - * Copyright (c) 2018 by Contributors + * Copyright (c) 2019 by Contributors * \file let_list.h * \brief LetList record let binding and insert let expression implicitly. * using it, one can treat AST as value instead of expression, @@ -46,6 +46,11 @@ namespace relay { */ class LetList { public: + ~LetList() { + if (lets_.size() > 0 && !used_) { + std::cout << "Warning: letlist not used" << std::endl; + } + } /*! * \brief insert a binding. * @@ -64,13 +69,13 @@ class LetList { /*! * \brief insert a binding. * - * \param ty the type of the binding. - * * \param expr the value of the binding. * + * \param ty the type of the binding. + * * \return a Var that hold the inserted expr. */ - Var Push(Type ty, Expr expr) { + Var Push(Expr expr, Type ty) { return Push(VarNode::make("x", ty), expr); } @@ -82,7 +87,7 @@ class LetList { * \return a Var that hold the inserted expr. */ Var Push(Expr expr) { - return Push(Type(), expr); + return Push(expr, Type()); } /*! @@ -129,6 +134,12 @@ class LetList { return ll.Get(f(&ll)); } + static Expr Let(const Expr& e, const std::function& f) { + return With([&](LetList* ll) { + return f(ll->Push(e)); + }); + } + private: std::vector > lets_; bool used_ = false; diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index acc60982cff44..6887c7a603227 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -18,7 +18,7 @@ */ /*! - * Copyright (c) 2018 by Contributors + * Copyright (c) 2019 by Contributors * * \file partial_eval.cc * @@ -426,8 +426,6 @@ TVM_ADD_FILELINE) Expr StripWithFuncId(const Expr& e); -Expr DeDup(const Expr& e); - Function AsFunc(const Expr& e) { if (e.as()) { return Downcast(e); @@ -963,86 +961,6 @@ class PartialEvaluator : public ExprFunctor FInterpreter executor_ = CPUInterpreter(); }; -/*! \brief Use a fresh Id for every Var to make the result well-formed. */ -Expr DeDup(const Expr& e) { - class DeDupMutator : public TypeMutator, - public ExprMutator, - public PatternMutator { - public: - TypeVar Fresh(const TypeVar& tv) { - TypeVar ret = TypeVarNode::make(tv->var->name_hint, tv->kind); - type_rename_[tv] = ret; - return ret; - } - - Var Fresh(const Var& v) { - Var ret = VarNode::make(v->name_hint(), VisitType(v->type_annotation)); - rename_[v] = ret; - return ret; - } - - Expr VisitExpr(const Expr& e) final { - return ExprMutator::VisitExpr(e); - } - - Expr VisitExpr_(const VarNode* op) final { - Var v = GetRef(op); - return rename_.count(v) != 0 ? rename_.at(v) : v; - } - - Expr VisitExpr_(const LetNode* op) final { - Var v = Fresh(op->var); - return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body)); - } - - Type VisitType(const Type& t) final { - return t.defined() ? 
TypeMutator::VisitType(t) : t; - } - - Expr VisitExpr_(const FunctionNode* op) final { - tvm::Array type_params; - for (const TypeVar& type_param : op->type_params) { - type_params.push_back(Fresh(type_param)); - } - tvm::Array params; - for (const Var& param : op->params) { - params.push_back(Fresh(param)); - } - return FunctionNode::make(params, - VisitExpr(op->body), - VisitType(op->ret_type), - type_params, - op->attrs); - } - - Pattern VisitPattern(const Pattern& p) final { - return PatternMutator::VisitPattern(p); - } - - Clause VisitClause(const Clause& c) final { - Pattern pat = VisitPattern(c->lhs); - return ClauseNode::make(pat, VisitExpr(c->rhs)); - } - - Type VisitType_(const TypeVarNode* op) final { - TypeVar v = GetRef(op); - return type_rename_.count(v) != 0 ? type_rename_.at(v) : v; - } - - Var VisitVar(const Var& v) final { - return Fresh(v); - } - - private: - std::unordered_map rename_; - std::unordered_map type_rename_; - }; - - Expr ret = DeDupMutator().VisitExpr(e); - CHECK_EQ(FreeVars(ret).size(), FreeVars(e).size()); - return ret; -} - /*! \brief Remap multiple Var sharing the same Id into the same Var. */ Expr Remap(const Expr& e) { class RemapMutator : public ExprMutator, public PatternMutator { diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc index 1b4b642eea8cb..19bf2cb4dc85c 100644 --- a/src/relay/pass/to_a_normal_form.cc +++ b/src/relay/pass/to_a_normal_form.cc @@ -18,9 +18,9 @@ */ /*! - * Copyright (c) 2018 by Contributors + * Copyright (c) 2019 by Contributors * - * \file to_anf.cc + * \file to_a_normal_form.cc * * \brief Turn implicit sharing into observable sharing. */ @@ -72,13 +72,16 @@ Scope LCA(Scope lhs, Scope rhs) { std::unordered_map CalcScope(const DependencyGraph& dg) { std::unordered_map expr_scope; + bool global_scope_used = false; Scope global_scope = std::make_shared(); for (auto it = dg.post_dfs_order.rbegin(); it != dg.post_dfs_order.rend(); ++it) { DependencyGraph::Node* n = *it; auto iit = n->parents.head; Scope s; if (iit == nullptr) { + CHECK(!global_scope_used); s = global_scope; + global_scope_used = true; } else { s = expr_scope.at(iit->value); iit = iit->next; @@ -88,13 +91,10 @@ std::unordered_map CalcScope(const DependencyGrap } expr_scope.insert({n, n->new_scope ? ChildScope(s) : s}); } + CHECK(global_scope_used); return expr_scope; } -bool IsPrimitiveFunction(const Expr& e) { - return e.as() && Downcast(e)->IsPrimitive(); -} - /* Special care is needed to handle local recursion. * Fill additionally take a (possibly null) Var argument, * If it is not null, Fill is required to bind the transformed result to that var. @@ -137,22 +137,26 @@ class Fill : ExprFunctor { Expr VisitExpr(const Expr& e, const Var& v) final { if (memo.count(e) == 0) { memo.insert({e, ExprFunctor::VisitExpr(e, v)}); + } else if (v.defined()) { + GetScope(e)->ll->Push(v, memo.at(e)); } - return memo.at(e); + auto ret = memo.at(e); + CHECK(IsAtomic(ret)); + return ret; } Expr VisitExpr(const Expr& e) { return this->VisitExpr(e, Var()); } - Expr Atomic(const Expr& orig, const Expr& now, const Var& v) { - return v.defined() ? GetScope(orig)->ll->Push(v, now) : now; + Expr Atomic(const Expr& e, const Var& v) { + return v.defined() ? GetScope(e)->ll->Push(v, e) : e; } Expr Compound(const Expr& orig, const Expr& now, const Var& v) { Var var = v.defined() ? 
v : - VarNode::make(std::string("x"), IncompleteTypeNode::make(Kind::kType)); + VarNode::make(std::string("x"), Type()); return GetScope(orig)->ll->Push(var, now); } @@ -205,7 +209,7 @@ class Fill : ExprFunctor { Expr VisitExpr_(const FunctionNode* f, const Var& v) final { Expr e = GetRef(f); Expr ret; - if (IsPrimitiveFunction(e)) { + if (f->IsPrimitive()) { ret = e; } else { ret = FunctionNode::make(f->params, @@ -231,22 +235,22 @@ class Fill : ExprFunctor { Expr VisitExpr_(const VarNode* vn, const Var& v) final { Expr e = GetRef(vn); - return Atomic(e, e, v); + return Atomic(e, v); } Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final { GlobalVar gv = GetRef(gvn); - return Atomic(gv, gv, v); + return Atomic(gv, v); } Expr VisitExpr_(const OpNode* op, const Var& v) final { Expr e = GetRef(op); - return Atomic(e, e, v); + return Atomic(e, v); } Expr VisitExpr_(const ConstructorNode* c, const Var& v) final { Expr e = GetRef(c); - return Atomic(e, e, v); + return Atomic(e, v); } Expr VisitExpr_(const MatchNode* m, const Var& v) final { @@ -294,11 +298,15 @@ Module ToANormalForm(const Module& m) { tvm::Map updates; auto funcs = m->functions; for (const auto& it : funcs) { + CHECK_EQ(FreeVars(it.second).size(), 0); Expr ret = TransformF([&](const Expr& e) { return ToANormalFormAux(e); }, it.second); - CHECK_EQ(FreeVars(ret).size(), 0); + CHECK_EQ(FreeVars(ret).size(), 0) + << AsText(ret) + << "should not has free vars: " + << FreeVars(ret); updates.Set(it.first, Downcast(ret)); } diff --git a/src/relay/pass/to_cps.cc b/src/relay/pass/to_cps.cc new file mode 100644 index 0000000000000..830b3b25a69b8 --- /dev/null +++ b/src/relay/pass/to_cps.cc @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * + * \file to_cps.cc + * + * \brief Turn a program to continuation passing style. + * + * Given a fresh type variable 'answer', + * continuation passing style(CPS) convert every function of a -> b to a -> (b -> anwer) -> answer. + * + * That is, instead of returning the result directly, + * function will now call another function (called the continuation) + * and return that value as a result instead. + * + * Continuation passing style turn all function call into tail call, + * which bound the stack size, prevent stack from overflowing during recursion, + * and allow tail call optimization. + * + * In relay, as tensor operation is the bottleneck, + * CPS is currently intended to transform the program before partial eval (PE), + * as it reify the control flow and enable PE to handle control flow join more agressively. + * + * For example, given 'let a = if b then c else d in e', it will transform the code into + * 'let f a = e in if b then f c else f d'. 
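+ * (In the rewritten form each branch ends in a tail call to f, the reified continuation
+ *  of the original let; this is the term-level counterpart of rewriting a function type
+ *  a -> b into a -> (b -> answer) -> answer as described above.)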
+ * This allow f to be optimized individually in both branch. + * + * We implement CPS conversion by higher order transform + * (see http://matt.might.net/articles/cps-conversion/). + * The basic idea is that we will recursively traverse the AST. + * During the traversal, there is an extra parameter, mcont, of expr -> expr. + * It is basically a continuation at the metalevel. + * All cases in the transform must return via the mcont, + * wheter directly invoking it, or indirectly by recursion. + */ +#include +#include +#include +#include "../ir/type_functor.h" +#include "let_list.h" +#include "pass_util.h" + +namespace tvm { +namespace relay { + +// we assume the data type has no closure - no idea how to look into datatype right now. + +Type Arrow(const Type& l, const Type& r) { + return FuncTypeNode::make({l}, r, {}, {}); +} + +Type CPSType(const Type& t, const TypeVar& answer); + +FuncType CPSFuncType(const FuncType& f, const TypeVar& answer) { + tvm::Array new_arg_types; + for (const Type& t : f->arg_types) { + new_arg_types.push_back(CPSType(t, answer)); + } + new_arg_types.push_back(Arrow(CPSType(f->ret_type, answer), answer)); + return FuncTypeNode::make(new_arg_types, answer, f->type_params, f->type_constraints); +} + +Type CPSType(const Type& t, const TypeVar& answer) { + struct CPSTypeMutator : TypeMutator { + explicit CPSTypeMutator(const TypeVar& answer) : answer(answer) { } + TypeVar answer; + Type VisitType_(const FuncTypeNode* t) final { + return CPSFuncType(GetRef(t), answer); + } + } mut(answer); + return mut(t); +} + +// transform global functions into cps form. +using CPSMap = std::unordered_map; + +// transform vars from the original program into new vars, so their type will be correct. +using VarMap = std::unordered_map; + +/* + * The meta continuation. + * There is 3 rules on the metacontinuation: + * 0: It can only use the argument once. + * The argument is code, and using it twice will duplicate code. + * Bound the argument via let instead. + * 1: If the size of the metacontinuation is unbounded, it can only be called once. + * It contain code, so calling it twice duplicate code. + * Reify the continuation and bound it instead. + * See the function 'reify' and the if case for more detail. + * 2: The argument must be effect free. + * It might reorder or drop the argument. + * Again, bound the argument via let instead. + * See the call case for more detail. + */ +using MCont = std::function; + +Function ToCPS(const Function& f, const Module& m, CPSMap* cm); + +Function ToCPS(const Function& f, const Module& m, CPSMap* cm, VarMap* vm, const TypeVar& answer) { + std::function remap = [&](const Var& v) { return vm->count(v) == 0 ? v : vm->at(v); }; + auto function_type = Downcast(f->checked_type()); + // Each MCont can be used at most once. 
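+  // CPSFunctor does the conversion proper: every VisitExpr_ case receives the
+  // meta-continuation k and must return through it, either by invoking it directly
+  // or by threading it through a recursive VisitExpr call (see the rules above).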
+ struct CPSFunctor : ExprFunctor, PatternMutator { + CPSFunctor(const std::function& remap, + const TypeVar& answer, + const Module& m, + VarMap* vm, + CPSMap* cm) : remap(remap), answer(answer), m(m), vm(vm), cm(cm) { } + const std::function& remap; + TypeVar answer; + Module m; + VarMap* vm; + CPSMap* cm; + + Expr VisitExpr_(const LetNode* op, const MCont& k) final { + return VisitExpr(op->value, [&](const Expr& v) { + return LetNode::make(remap(op->var), v, VisitExpr(op->body, k)); + }); + } + + Expr VisitExpr_(const FunctionNode* op, const MCont& k) final { + CHECK(!op->IsPrimitive()) << "primitive func not supported yet."; + return k(ToCPS(GetRef(op), m, cm, vm, answer)); + } + + Expr VisitExpr_(const ConstantNode* op, const MCont& k) final { + return k(GetRef(op)); + } + + Expr VisitExpr_(const VarNode* op, const MCont& k) final { + return k(remap(GetRef(op))); + } + + Pattern VisitPattern_(const PatternVarNode* op) final { + return PatternVarNode::make(remap(op->var)); + } + + Expr VisitExpr_(const GlobalVarNode* op, const MCont& k) final { + auto gv = GetRef(op); + if (cm->count(gv) == 0) { + auto cps_gv = GlobalVarNode::make(gv->name_hint + "_cps"); + cm->insert({gv, cps_gv}); + m->Add(cps_gv, ToCPS(m->Lookup(gv), m, cm)); + } + return k(cm->at(gv)); + } + + Expr VisitExpr_(const RefCreateNode* op, const MCont& k) final { + return VisitExpr(op->value, [&](const Expr& v) { return k(RefCreateNode::make(v)); }); + } + + Expr reify(const MCont& k) { + Var arg = VarNode::make("arg", Type()); + return FunctionNode::make({arg}, k(arg), Type(), {}, {}); + } + + Expr reify(const MCont& k, const std::function& cont) { + return LetList::Let(reify(k), + [&](const Var& f) { + return cont([&](const Expr& e) { return CallNode::make(f, {e}); }); + }); + } + + Expr VisitExpr_(const IfNode* op, const MCont& k) final { + return reify(k, [&](const MCont& kf) { + return VisitExpr(op->cond, + [&](const Expr& v) { + return IfNode::make(v, VisitExpr(op->true_branch, kf), VisitExpr(op->false_branch, kf)); + }); + }); + } + + Expr VisitExpr_(const MatchNode* op, const MCont& k) final { + return reify(k, [&](const MCont& kf) { + return VisitExpr(op->data, [&](const Expr& v) { + tvm::Array clauses; + for (const auto& c : op->clauses) { + clauses.push_back(ClauseNode::make(VisitPattern(c->lhs), VisitExpr(c->rhs, kf))); + } + return MatchNode::make(v, clauses); + }); + }); + } + + Expr VisitExpr_(const RefReadNode* op, const MCont& k) final { + return VisitExpr(op->ref, + [&](const Expr& r) { + return LetList::Let(RefReadNode::make(r), k); + }); + } + + Expr VisitExpr_(const RefWriteNode* op, const MCont& k) final { + return VisitExpr(op->ref, + [&](const Expr& r) { + return VisitExpr(op->value, + [&](const Expr& v) { + return LetList::Let(RefWriteNode::make(r, v), k); + }); + }); + } + + Expr VisitExpr_(const TupleNode* op, const MCont& k) final { + tvm::Array fields; + std::function next; + next = [&]() { + return (fields.size() == op->fields.size()) ? 
+ k(TupleNode::make(fields)) : + VisitExpr(op->fields[fields.size()], [&](const Expr& v) { + fields.push_back(v); + return next(); + }); + }; + return next(); + } + + Expr VisitExpr_(const TupleGetItemNode* op, const MCont& k) final { + return VisitExpr(op->tuple, [&](const Expr& v) { + return k(TupleGetItemNode::make(v, op->index)); + }); + } + + Expr VisitExpr_(const CallNode* op, const MCont& k) final { + if (op->op.as() || op->op.as()) { + tvm::Array args; + std::function next; + next = [&]() { + if (args.size() == op->args.size()) { + return LetList::Let(CallNode::make(op->op, args, op->attrs, op->type_args), k); + } else { + return VisitExpr(op->args[args.size()], [&](const Expr& v) { + args.push_back(v); + return next(); + }); + } + }; + return next(); + } else { + Expr f; + tvm::Array args; + std::function next; + next = [&]() { + if (args.size() == op->args.size()) { + args.push_back(reify(k)); + return Expr(CallNode::make(f, args, op->attrs, op->type_args)); + } else { + return VisitExpr(op->args[args.size()], [&](const Expr& v) { + args.push_back(v); + return next(); + }); + } + }; + return VisitExpr(op->op, [&](const Expr& v) { + f = v; + return next(); + }); + } + } + } mut(remap, answer, m, vm, cm); + Var k = VarNode::make("k", Arrow(CPSType(function_type->ret_type, answer), answer)); + tvm::Array new_params; + for (const Var& v : f->params) { + new_params.push_back(remap(v)); + } + new_params.push_back(k); + return FunctionNode::make(new_params, + mut.VisitExpr(f->body, + [&](const Expr& e) { return CallNode::make(k, {e}); }), + answer, + f->type_params, + f->attrs); +} + +Function ToCPS(const Function& f, const Module& m, CPSMap* cm) { + TypeVar answer = TypeVarNode::make("answer", kType); + VarMap var; + struct Remapper : ExprVisitor, PatternVisitor { + Remapper(const TypeVar& answer, VarMap* vm) : answer(answer), vm(vm) { } + TypeVar answer; + VarMap* vm; + void VisitExpr_(const VarNode* vn) final { + Var v = GetRef(vn); + if (vm->count(v) == 0) { + auto ret = VarNode::make(v->name_hint(), CPSType(v->checked_type(), answer)); + vm->insert({v, ret}); + } + } + + void VisitPattern(const Pattern& p) final { + PatternVisitor::VisitPattern(p); + } + + void VisitPattern_(const PatternVarNode* op) final { + VisitExpr(op->var); + } + } remap(answer, &var); + remap.VisitExpr(f); + Function ret = ToCPS(f, m, cm, &var, answer); + auto new_type_params = ret->type_params; + new_type_params.push_back(answer); + return FunctionNode::make(ret->params, ret->body, ret->ret_type, new_type_params, ret->attrs); +} + +Function ToCPS(const Function& f, const Module& m) { + CPSMap cps; + return ToCPS(f, m, &cps); +} + +Function UnCPS(const Function& f) { + CHECK_GT(f->params.size(), 0); + std::vector new_params; + for (const auto& p : f->params) { + new_params.push_back(VarNode::make(p->name_hint(), p->checked_type())); + } + auto cont_type = Downcast(new_params.back()->type_annotation); + new_params.pop_back(); + CHECK_EQ(cont_type->arg_types.size(), 1); + auto new_ret_type = Type(cont_type->arg_types[0]); + std::vector new_type_params; + for (const auto& tp : f->type_params) { + new_type_params.push_back(TypeVarNode::make(tp->var->name_hint, tp->kind)); + } + auto answer_type = new_type_params.back(); + new_type_params.pop_back(); + // TODO(@M.K.): make alphaequal work on free term + // CHECK(AlphaEqual(cont_type, Arrow(new_ret_type, answer_type))); + auto x = VarNode::make("x", new_ret_type); + auto cont = FunctionNode::make({x}, x, new_ret_type, {}, {}); + tvm::Array args; + for (const auto& 
p : new_params) { + args.push_back(p); + } + args.push_back(cont); + tvm::Array type_args; + for (const auto& tp : new_type_params) { + type_args.push_back(tp); + } + type_args.push_back(new_ret_type); + return FunctionNode::make(new_params, + CallNode::make(f, args, {}, type_args), + new_ret_type, + new_type_params, + f->attrs); +} + +TVM_REGISTER_API("relay._transform.to_cps") +.set_body_typed(static_cast(ToCPS)); + +TVM_REGISTER_API("relay._transform.un_cps") +.set_body_typed(UnCPS); + +namespace transform { + +Pass ToCPS() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Function(ToCPS(f, m)); + }; + return CreateFunctionPass(pass_func, 1, "ToCPS", {}); +} + +TVM_REGISTER_API("relay._transform.ToCPS") +.set_body_typed(ToCPS); + + +Pass UnCPS() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Function(UnCPS(f)); + }; + return CreateFunctionPass(pass_func, 1, "UnCPS", {}); +} + +TVM_REGISTER_API("relay._transform.UnCPS") +.set_body_typed(UnCPS); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index ff356cb9c9ef8..aa3cc029ef69c 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -368,10 +368,14 @@ class TypeInferencer : private ExprFunctor, // Build a subsitituion map up from the function type and type arguments. // Eventually allow the type vars to be passed in. - for (size_t i = 0; i < fn_ty->type_params.size(); i++) { + for (size_t i = 0; i < ty_args.size(); ++i) { subst_map.Set(fn_ty->type_params[i], ty_args[i]); } + for (size_t i = ty_args.size(); i < fn_ty->type_params.size(); ++i) { + subst_map.Set(fn_ty->type_params[i], IncompleteTypeNode::make(Kind::kType)); + } + Type ret_type = fn_ty->ret_type; // If the function type is incomplete, place a new IncompleteType @@ -437,13 +441,7 @@ class TypeInferencer : private ExprFunctor, } Array type_args = call->type_args; - if (type_args.size() == 0) { - for (size_t i = 0; i < fn_ty_node->type_params.size(); i++) { - type_args.push_back(IncompleteTypeNode::make(Kind::kType)); - } - } - - if (type_args.size() != fn_ty_node->type_params.size()) { + if (type_args.size() > fn_ty_node->type_params.size()) { this->ReportFatalError(GetRef(call), RELAY_ERROR("Incorrect number of type args in " << call->span << ": " diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 8d358e3f805f4..b2f7b9340ad42 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -17,14 +17,7 @@ import tvm from tvm import relay from tvm.relay import transform - - -def run_opt_pass(expr, opt_pass): - assert isinstance(opt_pass, transform.Pass) - mod = relay.Module.from_expr(expr) - mod = opt_pass(mod) - entry = mod[mod.entry_func] - return entry if isinstance(expr, relay.Function) else entry.body +from tvm.relay.testing import run_opt_pass def test_fuse_simple(): diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py index 400f5d79b1e40..555e418644bcf 100644 --- a/tests/python/relay/test_pass_gradient.py +++ b/tests/python/relay/test_pass_gradient.py @@ -22,15 +22,7 @@ from tvm.relay import create_executor, transform from tvm.relay.transform import gradient from tvm.relay.prelude import Prelude -from tvm.relay.testing import add_nat_definitions, make_nat_expr - - -def run_infer_type(expr): - mod = relay.Module.from_expr(expr) 
- mod = relay.Module.from_expr(expr) - mod = transform.InferType()(mod) - entry = mod[mod.entry_func] - return entry if isinstance(expr, relay.Function) else entry.body +from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type def rand(dtype='float32', *shape): diff --git a/tests/python/relay/test_pass_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py index c12298e465df3..51b8793f7667d 100644 --- a/tests/python/relay/test_pass_to_a_normal_form.py +++ b/tests/python/relay/test_pass_to_a_normal_form.py @@ -186,6 +186,19 @@ def test_function(): check_eval(anf_f(d), 8) +def test_gradient_if(): + x = relay.var("a", shape=(1, 16)) + y = relay.var("y", shape=(1, 16)) + cond = relay.var("cond", shape=(), dtype='uint1') + net = relay.If(cond, x, x) + net = relay.add(x, net) + net = relay.Function([cond,x,y], net) + mod = relay.Module.from_expr(net) + mod = relay.transform.ToANormalForm()(mod) + mod[mod.entry_func] = relay.transform.gradient(mod[mod.entry_func], mode='higher_order') + mod = relay.transform.ToANormalForm()(mod) + + if __name__ == '__main__': test_explicit_bound() test_order() @@ -195,3 +208,4 @@ def test_function(): test_let() test_nat_add() test_function() + test_gradient_if() diff --git a/tests/python/relay/test_pass_to_cps.py b/tests/python/relay/test_pass_to_cps.py new file mode 100644 index 0000000000000..128fc49b58ca9 --- /dev/null +++ b/tests/python/relay/test_pass_to_cps.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import tvm +from tvm import relay +from tvm.relay.analysis import alpha_equal, detect_feature +from tvm.relay.transform import to_cps, un_cps +from tvm.relay.feature import Feature +from tvm.relay.prelude import Prelude +from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, run_opt_pass +from tvm.relay import create_executor +from tvm.relay import Function, transform + + +def rand(dtype='float32', *shape): + return tvm.nd.array(np.random.rand(*shape).astype(dtype)) + + +# make sure cps work for recursion. +def test_recursion(): + mod = relay.Module() + p = Prelude(mod) + add_nat_definitions(p) + shape = (10, 10) + dtype = 'float32' + t = relay.TensorType(shape, dtype) + x = relay.var("x", t) + double = relay.Function([x], x + x) + i = relay.var("i", t) + func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i)) + mod[mod.entry_func] = func + mod[mod.entry_func] = to_cps(mod[mod.entry_func], mod=mod) + mod[mod.entry_func] = un_cps(mod[mod.entry_func]) + ex = create_executor(mod=mod) + i_nd = rand(dtype, *shape) + forward = ex.evaluate(mod.entry_func)(i_nd) + tvm.testing.assert_allclose(forward.asnumpy(), 8 * i_nd.asnumpy()) + + +# This serve as an integration test. 
+# It test that, given a program with reference, +# cps and pe can completely eliminate the allocation of reference. +def test_cps_pe(): + def destroy_ref(x): + x = run_infer_type(x) + x = to_cps(x) + x = run_infer_type(x) + y = un_cps(x) + y = run_infer_type(y) + x = run_opt_pass(x, transform.Sequential([transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)])) + assert Feature.fRefCreate not in detect_feature(x) + unit = relay.Function([], relay.const(0., dtype='float32')) + f_ref = relay.Var("f_ref") + + one = relay.const(1., dtype='float32') + two = relay.const(2., dtype='float32') + cond = relay.var(shape=(), dtype='uint1', name_hint='cond') + true_branch = relay.RefWrite(f_ref, relay.Function([], one)) + false_branch = relay.RefWrite(f_ref, relay.Function([], two)) + if_expr = relay.If(cond, true_branch, false_branch) + + stmt = relay.Let(f_ref, relay.RefCreate(unit), + relay.Let(relay.Var("x"), if_expr, + relay.Call(relay.RefRead(f_ref), []))) + + F = relay.Function([cond], stmt) + destroy_ref(F) + + G = relay.Function([cond], relay.If(cond, one, two)) + G = relay.transform.gradient(G) + destroy_ref(G) + + x = relay.var("x", shape=(1, 16)) + y = relay.var("y", shape=(1, 16)) + z = relay.var("z", shape=(1, 16)) + cond = relay.var("cond", shape=(), dtype='uint1') + H = relay.If(cond, x, y) + H = relay.add(H, z) + H = relay.Function([cond,x,y,z], H) + H = relay.transform.gradient(H) + destroy_ref(H) + + +if __name__ == '__main__': + test_recursion() + test_cps_pe() From 882ae1267cf648a3c818c21bc964cbc2495f2548 Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Tue, 2 Jul 2019 20:25:49 -0700 Subject: [PATCH 12/26] producing simulation statistics instead of time to get useful information out of simulation runs (#3481) --- .../frontend/deploy_resnet_on_vta.py | 40 +++++++++++++------ 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/vta/tutorials/frontend/deploy_resnet_on_vta.py b/vta/tutorials/frontend/deploy_resnet_on_vta.py index 3e252172444a0..3035decb2160d 100644 --- a/vta/tutorials/frontend/deploy_resnet_on_vta.py +++ b/vta/tutorials/frontend/deploy_resnet_on_vta.py @@ -229,25 +229,39 @@ m.set_input(**params) m.set_input('data', image) -# Perform inference: we run the module 4 times, -# and repeat 3 times to get error bounds -timer = m.module.time_evaluator("run", ctx, number=4, repeat=3) -tcost = timer() +# Perform inference and gather execution statistics +# More on: https://docs.tvm.ai/api/python/module.html#tvm.module.Module.time_evaluator +num = 4 # number of times we run module for a single measurement +rep = 3 # number of measurements (we derive std dev from this) +timer = m.module.time_evaluator("run", ctx, number=num, repeat=rep) + +if env.TARGET == "sim": + simulator.clear_stats() + timer() + sim_stats = simulator.stats() + print("\nExecution statistics:") + for k, v in sim_stats.items(): + # Since we execute the workload many times, we need to normalize stats + # Note that there is always one warm up run + # Therefore we divide the overall stats by (num * rep + 1) + print("\t{:<16}: {:>16}".format(k, v // (num * rep + 1))) +else: + tcost = timer() + std = np.std(tcost.results) * 1000 / env.BATCH + mean = tcost.mean * 1000 / env.BATCH + print("\nPerformed inference in %.2fms/sample (std = %.2f)" % (mean, std)) # Get classification results tvm_output = m.get_output(0, tvm.nd.empty((env.BATCH, 1000), "float32", remote.cpu(0))) top_categories = np.argsort(tvm_output.asnumpy()[0]) # Report top-5 classification results -std = 
np.std(tcost.results) * 1000 / env.BATCH -mean = tcost.mean * 1000 / env.BATCH -print("%s prediction" % model) -print(" #1:", synset[top_categories[-1]]) -print(" #2:", synset[top_categories[-2]]) -print(" #3:", synset[top_categories[-3]]) -print(" #4:", synset[top_categories[-4]]) -print(" #5:", synset[top_categories[-5]]) -print("Performed inference in %.2fms/sample (std = %.2f)" % (mean, std)) +print("\n%s prediction" % model) +print("\t#1:", synset[top_categories[-1]]) +print("\t#2:", synset[top_categories[-2]]) +print("\t#3:", synset[top_categories[-3]]) +print("\t#4:", synset[top_categories[-4]]) +print("\t#5:", synset[top_categories[-5]]) # This just checks that one of the 5 top categories # is one variety of cat; this is by no means an accurate From f3dcab42073848bff430ce22e3c0e91cdf93daaa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Wed, 3 Jul 2019 07:21:44 -0700 Subject: [PATCH 13/26] [Relay] use transform instead of ir_pass for CPS (#3485) --- python/tvm/relay/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index f77a532ba7387..2805e0b429fa0 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -457,7 +457,7 @@ def ToCPS(expr, mod=None): result: tvm.relay.Pass The registered pass that transforms an expression into CPS. """ - return _ir_pass.to_cps(expr, mod) + return _transform.to_cps(expr, mod) def EtaExpand(): From 287078c33db85d4f312d8d2457a064442d9d18c3 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Wed, 3 Jul 2019 10:08:04 -0700 Subject: [PATCH 14/26] Pre-allocate buffer for x86 roi_align (#3475) * Pre-allocate buffer for x86 roi_align * Fix typo --- topi/python/topi/x86/roi_align.py | 44 +++++++++++++++++++-------- topi/tests/python/test_topi_vision.py | 1 + 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/topi/python/topi/x86/roi_align.py b/topi/python/topi/x86/roi_align.py index a8ad387a242f3..26b84be9585b3 100644 --- a/topi/python/topi/x86/roi_align.py +++ b/topi/python/topi/x86/roi_align.py @@ -16,14 +16,17 @@ # under the License. # pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, undefined-variable, too-many-nested-blocks, too-many-branches, too-many-statements """Non-maximum suppression operator for intel cpu""" +import math import tvm from tvm import hybrid from ..vision.rcnn import roi_align_nchw +from ..tensor import full +from ..util import get_const_tuple @hybrid.script -def roi_align_nchw_ir(data, rois, pooled_size, spatial_scale, sample_ratio): +def roi_align_nchw_ir(data, rois, w_pc, pos_pc, pooled_size, spatial_scale, sample_ratio): """Hybrid routing fo ROI align operator in NCHW layout. Parameters @@ -35,6 +38,12 @@ def roi_align_nchw_ir(data, rois, pooled_size, spatial_scale, sample_ratio): 2-D with shape [num_roi, 5]. 
The last dimension should be in format of [batch_index, w_start, h_start, w_end, h_end] + w_pc : tvm.Tensor or numpy NDArray + 3-D weight pre-calculation buffer + + pos_pc : tvm.Tensor or numpy NDArray + 3-D position pre-calculation buffer + pooled_size : tvm ConsExpr [out_height, out_width] @@ -57,9 +66,6 @@ def roi_align_nchw_ir(data, rois, pooled_size, spatial_scale, sample_ratio): pooled_size_h = pooled_size[0] pooled_size_w = pooled_size[1] output = output_tensor((num_rois, channels, pooled_size_h, pooled_size_w), data.dtype) - max_num_pc_index = height * width * pooled_size_h * pooled_size_w - w_pc = allocate((num_rois, max_num_pc_index, 4), data.dtype) - pos_pc = allocate((num_rois, max_num_pc_index, 4), "int32") for n in parallel(num_rois): roi_batch_index = int32(rois[n, 0]) @@ -76,18 +82,16 @@ def roi_align_nchw_ir(data, rois, pooled_size, spatial_scale, sample_ratio): roi_bin_grid_h = sample_ratio roi_bin_grid_w = roi_bin_grid_h - div_h = roi_h / pooled_size_h - div_w = roi_w / pooled_size_w - rounded_div_h = int32(div_h) * 1.0 - rounded_div_w = int32(div_w) * 1.0 + rounded_bin_h = int32(bin_h) * 1.0 + rounded_bin_w = int32(bin_w) * 1.0 if sample_ratio <= 0: # Cannot use ceil function since hybrid script # doesn't support Call as indexing - roi_bin_grid_h = int32(div_h) - roi_bin_grid_w = int32(div_w) - if rounded_div_h < div_h: + roi_bin_grid_h = int32(bin_h) + roi_bin_grid_w = int32(bin_w) + if rounded_bin_h < bin_h: roi_bin_grid_h += 1 - if rounded_div_w < div_w: + if rounded_bin_w < bin_w: roi_bin_grid_w += 1 count = roi_bin_grid_h * roi_bin_grid_w @@ -211,7 +215,21 @@ def roi_align_nchw_cpu(data, rois, pooled_size, spatial_scale, sample_ratio=-1): """ if not isinstance(pooled_size, (tuple, list)): pooled_size = (pooled_size, pooled_size) + + # Pre-allocate intermediate buffer + if sample_ratio > 0: + max_roi_bin_grid_w = max_roi_bin_grid_h = sample_ratio + else: + _, _, height, width = get_const_tuple(data.shape) + max_roi_bin_grid_h = math.ceil(height / pooled_size[0]) + max_roi_bin_grid_w = math.ceil(width / pooled_size[1]) + max_pc_shape = (rois.shape[0], max_roi_bin_grid_h * max_roi_bin_grid_w + * pooled_size[0] * pooled_size[1], 4) + w_pc_buffer = full(max_pc_shape, data.dtype, 0) + pos_pc_buffer = full(max_pc_shape, "int32", 0) + pooled_size = tvm.convert(pooled_size) spatial_scale = tvm.const(spatial_scale, "float32") sample_ratio = tvm.const(sample_ratio, "int32") - return roi_align_nchw_ir(data, rois, pooled_size, spatial_scale, sample_ratio) + return roi_align_nchw_ir(data, rois, w_pc_buffer, pos_pc_buffer, + pooled_size, spatial_scale, sample_ratio) diff --git a/topi/tests/python/test_topi_vision.py b/topi/tests/python/test_topi_vision.py index 3a0b134890379..08b6d2e7d4148 100644 --- a/topi/tests/python/test_topi_vision.py +++ b/topi/tests/python/test_topi_vision.py @@ -306,6 +306,7 @@ def test_roi_align(): verify_roi_align(1, 16, 32, 64, 7, 1.0, -1) verify_roi_align(4, 16, 32, 64, 7, 0.5, 2) verify_roi_align(1, 32, 32, 80, 8, 0.0625, 2) + verify_roi_align(1, 32, 500, 80, 8, 0.0625, 2) def verify_roi_pool(batch, in_channel, in_size, num_roi, pooled_size, spatial_scale): From b3f3ab5593c1949947c9872c8df1479975116a95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Wed, 3 Jul 2019 21:58:51 -0700 Subject: [PATCH 15/26] [Relay] Fix PE (#3482) --- include/tvm/relay/module.h | 2 +- src/relay/ir/expr_functor.cc | 26 ++++- src/relay/ir/type_functor.cc | 6 +- src/relay/ir/type_functor.h | 1 + src/relay/pass/let_list.h | 2 +- 
src/relay/pass/partial_eval.cc | 116 ++++++++++--------- src/relay/pass/type_infer.cc | 1 + src/relay/pass/util.cc | 10 ++ tests/python/relay/test_pass_partial_eval.py | 4 +- 9 files changed, 101 insertions(+), 67 deletions(-) diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 638f75968fd33..4a3ff0b6eb19c 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -55,7 +55,7 @@ struct Module; * The functional style allows users to construct custom * environments easily, for example each thread can store * a Module while auto-tuning. - * */ + */ class ModuleNode : public RelayNode { public: diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 36692c5c571b1..0434e2ac59c64 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,7 @@ */ /*! - * Copyright (c) 2018 by Contributors + * Copyright (c) 2019 by Contributors * \file src/tvm/relay/expr_mutator.cc * \brief A wrapper around ExprFunctor which functionally updates the AST. * @@ -26,6 +26,7 @@ * the cost of using functional updates. */ #include +#include #include "type_functor.h" namespace tvm { @@ -353,7 +354,7 @@ TVM_REGISTER_API("relay._analysis.post_order_visit") }); // Implement bind. -class ExprBinder : public ExprMutator { +class ExprBinder : public ExprMutator, PatternMutator { public: explicit ExprBinder(const tvm::Map& args_map) : args_map_(args_map) { @@ -383,13 +384,26 @@ class ExprBinder : public ExprMutator { } } + Pattern VisitPattern(const Pattern& p) final { + return PatternMutator::VisitPattern(p); + } + + Clause VisitClause(const Clause& c) final { + Pattern pat = VisitPattern(c->lhs); + return ClauseNode::make(pat, VisitExpr(c->rhs)); + } + + Var VisitVar(const Var& v) final { + return Downcast(VisitExpr(v)); + } + private: const tvm::Map& args_map_; }; Expr Bind(const Expr& expr, const tvm::Map& args_map) { if (const FunctionNode* func = expr.as()) { - Expr new_body = ExprBinder(args_map).Mutate(func->body); + Expr new_body = ExprBinder(args_map).VisitExpr(func->body); Array new_params; for (Var param : func->params) { if (!args_map.count(param)) { @@ -406,7 +420,7 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { func->type_params, func->attrs); } else { - return ExprBinder(args_map).Mutate(expr); + return ExprBinder(args_map).VisitExpr(expr); } } diff --git a/src/relay/ir/type_functor.cc b/src/relay/ir/type_functor.cc index 516f4c875b20c..cde68c50daeff 100644 --- a/src/relay/ir/type_functor.cc +++ b/src/relay/ir/type_functor.cc @@ -92,6 +92,10 @@ void TypeVisitor::VisitType_(const TypeDataNode* op) { } } +Type TypeMutator::VisitType(const Type& t) { + return t.defined() ? TypeFunctor::VisitType(t) : t; +} + // Type Mutator. Array TypeMutator::MutateArray(Array arr) { // The array will do copy on write @@ -221,7 +225,7 @@ class TypeBinder : public TypeMutator { }; Type Bind(const Type& type, const tvm::Map& args_map) { - return type.defined() ? 
TypeBinder(args_map).VisitType(type) : type; + return TypeBinder(args_map).VisitType(type); } } // namespace relay diff --git a/src/relay/ir/type_functor.h b/src/relay/ir/type_functor.h index 27ac288fe48db..c3ee14eedd487 100644 --- a/src/relay/ir/type_functor.h +++ b/src/relay/ir/type_functor.h @@ -139,6 +139,7 @@ class TypeVisitor : public TypeFunctor { // Mutator that transform a type to another one. class TypeMutator : public TypeFunctor { public: + Type VisitType(const Type& t) override; Type VisitType_(const TypeVarNode* op) override; Type VisitType_(const TensorTypeNode* op) override; Type VisitType_(const IncompleteTypeNode* op) override; diff --git a/src/relay/pass/let_list.h b/src/relay/pass/let_list.h index 1b422d2a878f0..73c5fe3abc22c 100644 --- a/src/relay/pass/let_list.h +++ b/src/relay/pass/let_list.h @@ -48,7 +48,7 @@ class LetList { public: ~LetList() { if (lets_.size() > 0 && !used_) { - std::cout << "Warning: letlist not used" << std::endl; + LOG(WARNING) << "letlist not used"; } } /*! diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index 6887c7a603227..b7f12b65751db 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -64,7 +64,7 @@ * 3: The generated code reuses bindings (although they are not shadowed), * so we have to deduplicate them. * - * 4: In the generated code, multiple VarNode might have same Id. + * 4: In the generated code, as it call TypeSubst, multiple VarNode might have same Id. * While it is permitted, most pass use NodeHash for Var, * and having multiple VarNode for same Id break them. * Thus we remap them to a single Id for now. @@ -216,9 +216,9 @@ Static MkSRef() { } using Func = std::function&, - const Attrs&, - const Array&, - LetList*)>; + const Attrs&, + const Array&, + LetList*)>; struct SFuncNode : StaticNode { Func func; @@ -256,6 +256,7 @@ class Environment { void Insert(const Var& v, const PStatic& ps) { CHECK(ps.defined()); + CHECK_EQ(env_.back().locals.count(v), 0); env_.back().locals[v] = ps; } @@ -287,12 +288,17 @@ class Environment { /*! * \brief As our store require rollback, we implement it as a frame. - * every time we need to copy the store, a new frame is insert. - * every time we roll back, a frame is popped. + * + * Every time we need to copy the store, a new frame is insert. + * Every time we roll back, a frame is popped. */ struct StoreFrame { std::unordered_map store; - /*! \brief on unknown effect, history_valid is set to true to signal above frame is outdated */ + /*! + * \brief On unknown effect, history_valid is set to true to signal above frame is outdated. + * + * It only outdate the frame above it, but not the current frame. 
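+ *
+ * Concretely, Invalidate() pushes a fresh invalidated frame on top of the
+ * stack instead of mutating the current one, and Lookup() walks the frames
+ * from newest to oldest, returning an empty PStatic (no static information)
+ * as soon as it crosses an invalidated frame; both are defined on Store
+ * below. For example:
+ *
+ *   store.Insert(r, s);   // the newest frame now maps r -> s
+ *   store.Invalidate();   // push an invalidated frame on top
+ *   store.Lookup(r);      // empty PStatic: r is treated as unknown again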
+ */ bool history_valid = true; explicit StoreFrame(const std::unordered_map& store) : store(store) { } StoreFrame() = default; @@ -310,6 +316,7 @@ class Store { } void Insert(const SRefNode* r, const PStatic& ps) { + CHECK(r); store_.back().store[r] = ps; } @@ -317,19 +324,21 @@ class Store { PStatic Lookup(const SRefNode* r) { auto rit = store_.rbegin(); while (rit != store_.rend()) { - if (!rit->history_valid) { - return PStatic(); - } if (rit->store.find(r) != rit->store.end()) { return rit->store.find(r)->second; } + if (!rit->history_valid) { + return PStatic(); + } ++rit; } return PStatic(); } void Invalidate() { - store_.back().history_valid = false; + StoreFrame sf; + sf.history_valid = false; + store_.push_back(sf); } private: @@ -341,6 +350,10 @@ class Store { store_->store_.push_back(StoreFrame()); } ~StoreFrameContext() { + // push one history valid frame off. + while (!store_->store_.back().history_valid) { + store_->store_.pop_back(); + } store_->store_.pop_back(); } }; @@ -442,13 +455,7 @@ Function AsFunc(const Expr& e) { class PartialEvaluator : public ExprFunctor, public PatternFunctor { public: - PartialEvaluator(const tvm::Array& free_vars, - const Module& mod) : - mod_(mod) { - for (const Var& v : free_vars) { - env_.Insert(v, NoStatic(v)); - } - } + PartialEvaluator(const Module& mod) : mod_(mod) { } PStatic VisitExpr(const Expr& e, LetList* ll) final { PStatic ret = ExprFunctor::VisitExpr(e, ll); @@ -484,23 +491,23 @@ class PartialEvaluator : public ExprFunctor return env_.Lookup(GetRef(op)); } - PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final { - GlobalVar gv = GetRef(op); + PStatic VisitGlobalVar(const GlobalVar& gv) { + CHECK(mod_.defined()); if (gv_map_.count(gv) == 0) { - if (mod_.defined()) { - Function func = mod_->Lookup(gv); - InitializeFuncId(func); - Func f = VisitFuncStatic(func, gv); - gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)}); - func = AsFunc(PostProcess(VisitFuncDynamic(func, f))); - mod_->Update(gv, func); - } else { - gv_map_.insert({gv, NoStatic(gv)}); - } + Function func = mod_->Lookup(gv); + InitializeFuncId(func); + Func f = VisitFuncStatic(func, gv); + gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)}); + func = AsFunc(PostProcess(VisitFuncDynamic(func, f))); + mod_->Update(gv, func); } return gv_map_.at(gv); } + PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final { + return VisitGlobalVar(GetRef(op)); + } + PStatic VisitExpr_(const LetNode* op, LetList* ll) final { env_.Insert(op->var, VisitExpr(op->value, ll)); return VisitExpr(op->body, ll); @@ -629,7 +636,7 @@ class PartialEvaluator : public ExprFunctor subst.Set(func->type_params[i], type_args[i]); } for (size_t i = type_args.size(); i < func->type_params.size(); ++i) { - subst.Set(func->type_params[i], Type()); + subst.Set(func->type_params[i], IncompleteTypeNode::make(kType)); } std::vector