From 5d70b008668ba244dad598d1ce3e8a79e2d6cede Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Tue, 9 Apr 2019 13:20:56 +0800
Subject: [PATCH 1/7] [Relay] InferCorrectLayout for strided_slice &
 min_num_branches option in CombineParallelConv2D (#2961)

* [Relay] InferCorrectLayout for strided_slice

* Add min_num_branches option to CombineParallelConv2D

* Return undef if original layout contains split axes
---
 python/tvm/relay/ir_pass.py                   |  9 ++-
 src/relay/op/tensor/transform.cc              | 61 ++++++++++++++++++-
 src/relay/pass/combine_parallel_conv2d.cc     | 15 ++++-
 .../python/relay/test_pass_alter_op_layout.py | 43 +++++++++++++
 .../test_pass_combine_parallel_conv2d.py      |  8 +--
 5 files changed, 125 insertions(+), 11 deletions(-)

diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py
index 8eb0adc3da1a..b3d323b2aed6 100644
--- a/python/tvm/relay/ir_pass.py
+++ b/python/tvm/relay/ir_pass.py
@@ -722,20 +722,23 @@ def fuse_ops(expr, opt_level=1):
     return _ir_pass.FuseOps(expr, opt_level)
 
 
-def combine_parallel_conv2d(expr):
-    """Fold multiple conv2d into one.
+def combine_parallel_conv2d(expr, min_num_branches=3):
+    """Combine multiple conv2d into one.
 
     Parameters
     ----------
     expr : tvm.relay.Expr
         The input expression.
 
+    min_num_branches : int
+        The minimum number of parallel branches required for the transformation
+        to be applied.
+
     Returns
     -------
     transformed_expr : tvm.relay.Expr
         Transformed expression
     """
-    return _ir_pass.CombineParallelConv2D(expr)
+    return _ir_pass.CombineParallelConv2D(expr, min_num_branches)
 
 
 def alter_op_layout(expr):
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 15eaceb41a2d..f86156bdbddc 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1722,6 +1722,64 @@ bool StridedSliceRel(const Array<Type>& types,
 }
 
 
+Array<Array<Layout> > StridedSliceInferCorrectLayout(
+    const Attrs& attrs,
+    const Array<Layout>& new_in_layouts,
+    const Array<Layout>& old_in_layouts,
+    const Array<Array<IndexExpr> >& old_in_shapes) {
+  CHECK(old_in_layouts.defined());
+  CHECK_EQ(old_in_layouts.size(), 1);
+  CHECK(old_in_shapes.defined());
+  CHECK_EQ(old_in_shapes.size(), 1);
+
+  auto layout = old_in_layouts[0];
+  if (layout.defined() && new_in_layouts.defined()) {
+    CHECK_EQ(new_in_layouts.size(), 1);
+    auto new_layout = new_in_layouts[0];
+    auto shape = old_in_shapes[0];
+
+    // NOTE: Discard "const" qualifier here.
+    auto *params = const_cast<StridedSliceAttrs*>(attrs.as<StridedSliceAttrs>());
+
+    Array<Integer> new_begin, new_end;
+
+    for (size_t i = 0; i < params->begin.size(); i++) {
+      const LayoutAxis& axis = layout[i];
+      if (!axis.IsPrimal()) {
+        // original layout that contains split axes is not supported
+        return {{Layout::Undef()}, {Layout::Undef()}};
+      }
+      auto factor = new_layout.FactorOf(axis);
+      if (factor == -1) {
+        new_begin.push_back(params->begin[i]);
+        new_end.push_back(params->end[i]);
+      } else {
+        if (params->strides.defined() && i < params->strides.size()) {
+          auto stride = params->strides[i];
+          // arbitrary stride is not supported
+          if (stride.defined() && stride->value != 1) {
+            return {{Layout::Undef()}, {Layout::Undef()}};
+          }
+        }
+        int64_t begin = params->begin[i].defined() ? params->begin[i]->value : 0;
+        int64_t end = params->end[i].defined() ? params->end[i]->value :
+                      shape[i].as<IntImm>()->value;
+        if (begin % factor || end % factor) {
+          // transform to original layout
+          return {{Layout::Undef()}, {Layout::Undef()}};
+        }
+        new_begin.push_back(tvm::Integer(begin / factor));
+        new_end.push_back(tvm::Integer(end / factor));
+      }
+    }
+    layout = new_layout;
+    params->begin = new_begin;
+    params->end = new_end;
+  }
+  return {{layout}, {layout}};
+}
+
+
 // Positional relay function to create StridedSlice operator used by frontend FFI.
 Expr MakeStridedSlice(Expr data,
                       Array<Integer> begin,
@@ -1783,7 +1841,8 @@ Examples::
 .set_attrs_type_key("relay.attrs.StridedSliceAttrs")
 .add_type_rel("StridedSlice", StridedSliceRel)
 .set_attr<FTVMCompute>("FTVMCompute", StridedSliceCompute)
-.set_attr<TOpPattern>("TOpPattern", kInjective);
+.set_attr<TOpPattern>("TOpPattern", kInjective)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", StridedSliceInferCorrectLayout);
 
 // relay.split
diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc
index cb53698762ad..cd7a852bcad7 100644
--- a/src/relay/pass/combine_parallel_conv2d.cc
+++ b/src/relay/pass/combine_parallel_conv2d.cc
@@ -159,10 +159,15 @@ class BranchGroupFinder : private ExprVisitor {
 
 class ParallelConv2DCombiner {
  public:
+  explicit ParallelConv2DCombiner(uint64_t min_num_branches) : min_num_branches_(min_num_branches) {
+  }
+
   Expr Combine(const Expr& expr) {
     auto groups = BranchGroupFinder().Find(expr);
     for (const Group& group : groups) {
-      if (group.size() < 2) continue;
+      if (group.size() < min_num_branches_) {
+        continue;
+      }
       CombineBranches(group);
     }
     return ExprSubst(expr, std::move(subst_map_));
@@ -170,6 +175,7 @@ class ParallelConv2DCombiner {
 
  private:
   std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map_;
+  uint64_t min_num_branches_;
 
   std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
     int64_t num_filters = 0;  // number of filters of the transformed weight
@@ -343,11 +349,14 @@ class ParallelConv2DCombiner {
   }
 };
 
-Expr CombineParallelConv2D(const Expr& expr) { return ParallelConv2DCombiner().Combine(expr); }
+/*! \brief Combine parallel conv2d if the number of branches is at least min_num_branches. */
+Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) {
+  return ParallelConv2DCombiner(min_num_branches).Combine(expr);
+}
 
 TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = CombineParallelConv2D(args[0]);
+    *ret = CombineParallelConv2D(args[0], args[1]);
   });
 
 }  // namespace relay
diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index 0f21288245d9..f7a1c83ddff1 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -472,6 +472,48 @@ def expected():
     assert(alpha_equal(a, b))
 
 
+def test_alter_layout_strided_slice():
+    """Test rewriting strided_slice during alter_op_layout"""
+    def before():
+        x = relay.var("x", shape=(1, 32, 28, 28))
+        weight = relay.var('weight', shape=(32, 32, 3, 3))
+        y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
+        y = relay.strided_slice(y, begin=[0, 16], end=[None, None])
+        y = relay.Function(free_vars(y), y)
+        return y
+
+    @register_alter_op_layout("nn.conv2d", level=109)
+    def alter_conv2d(attrs, inputs, tinfos):
+        data, weight = inputs
+        new_attrs = dict(attrs)
+        new_attrs['data_layout'] = 'NCHW4c'
+        return relay.nn.conv2d(data, weight, **new_attrs)
+
+    def expected():
+        x = relay.var("x", shape=(1, 32, 28, 28))
+        weight = relay.var("weight")
+        x = relay.layout_transform(x, "NCHW", "NCHW4c")
+        y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1),
+                            data_layout="NCHW4c")
+        y = relay.strided_slice(y, begin=[0, 4], end=[None, 8])
+        y = relay.layout_transform(y, "NCHW4c", "NCHW")
+        y = relay.Function(free_vars(y), y)
+        return y
+
+    a = before()
+    a = infer_type(a)
+    a = canonicalize_ops(a)
+    a = infer_type(a)
+
+    a = alter_op_layout(a)
+    a = infer_type(a)
+
+    b = expected()
+    b = infer_type(b)
+
+    assert(alpha_equal(a, b))
+
+
 if __name__ == "__main__":
     test_alter_op()
     test_alter_return_none()
@@ -482,3 +524,4 @@ def expected():
     test_alter_layout_scalar()
     test_alter_layout_concatenate()
     test_alter_layout_nchw_upsamping_op()
+    test_alter_layout_strided_slice()
diff --git a/tests/python/relay/test_pass_combine_parallel_conv2d.py b/tests/python/relay/test_pass_combine_parallel_conv2d.py
index 0d6e1e39b509..3bb656b2bda5 100644
--- a/tests/python/relay/test_pass_combine_parallel_conv2d.py
+++ b/tests/python/relay/test_pass_combine_parallel_conv2d.py
@@ -55,7 +55,7 @@ def check(x_shape, channels1, channels2, channels3, channels4):
 
     y_before = before(x, w1, w2, w3, w4)
     y = relay.ir_pass.infer_type(y_before)
-    y = relay.ir_pass.combine_parallel_conv2d(y)
+    y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
     y = relay.ir_pass.infer_type(y)
     y_expected = expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4)
     y_expected = relay.ir_pass.infer_type(y_expected)
@@ -102,7 +102,7 @@ def check(x_shape, channels1, channels2):
     bias = relay.var("bias", shape=(channels2, 1, 1))
     y_before = before(x, w1, w2, scale1, scale2, bias)
     y = relay.ir_pass.infer_type(y_before)
-    y = relay.ir_pass.combine_parallel_conv2d(y)
+    y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
     y = relay.ir_pass.infer_type(y)
     y_expected = expected(x, w1, w2, scale1, scale2, bias, channels1, channels2)
     y_expected = relay.ir_pass.infer_type(y_expected)
@@ -142,7 +142,7 @@ def check(x_shape, channels1, channels2):
     scale2 = relay.var("scale2", shape=(1,))
     y_before = before(x, w1, w2, scale1, scale2)
     y = relay.ir_pass.infer_type(y_before)
-    y = relay.ir_pass.combine_parallel_conv2d(y)
+    y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
     y = relay.ir_pass.infer_type(y)
     y_expected = expected(x, w1, w2, scale1, scale2, channels1, channels2)
     y_expected = relay.ir_pass.infer_type(y_expected)
@@ -179,7 +179,7 @@ def check(x_shape, repeat):
     w = relay.var("w", shape=(out_c, in_c, 1, 1))
     y_before = before(x, w, repeat)
     y = relay.ir_pass.infer_type(y_before)
-    y = relay.ir_pass.combine_parallel_conv2d(y)
+    y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
     y = relay.ir_pass.infer_type(y)
     y_expected = expected(x, w, out_c, repeat)
     y_expected = relay.ir_pass.infer_type(y_expected)

From 28f354bf1efc9e3084cff68a0d53866a4ec6d88f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?=
Date: Mon, 8 Apr 2019 22:51:00 -0700
Subject: [PATCH 2/7] [Relay] Add expr_visitor, fix expr_functor exponential
 blowup problem (#2988)

* save

* lint
---
 python/tvm/relay/__init__.py            |  1 +
 python/tvm/relay/expr_functor.py        | 67 +++++++++++++++++++++++--
 tests/python/relay/test_expr_functor.py | 28 ++++++++++-
 3 files changed, 91 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index 1ec09aac2606..2ab4ca2e1404 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -101,6 +101,7 @@
 
 # ExprFunctor
 ExprFunctor = expr_functor.ExprFunctor
+ExprVisitor = expr_functor.ExprVisitor
 ExprMutator = expr_functor.ExprMutator
 
 # Parser
diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py
index a924847fa238..9ca094158c1c 100644
--- a/python/tvm/relay/expr_functor.py
+++ b/python/tvm/relay/expr_functor.py
@@ -36,9 +36,8 @@ def __init__(self):
     # pylint: disable=no-else-return
     def visit(self, expr):
         """Apply the visitor to an expression."""
-        found = self.memo_map.get(expr)
-        if found:
-            return found
+        if expr in self.memo_map:
+            return self.memo_map[expr]
 
         if isinstance(expr, Function):
             res = self.visit_function(expr)
@@ -126,6 +125,68 @@ def visit_match(self, _):
         raise NotImplementedError()
 
 
+class ExprVisitor(ExprFunctor):
+    """
+    A visitor over Expr.
+
+    The default behavior recursively traverses the AST.
+    """
+    def visit_tuple(self, t):
+        for x in t.fields:
+            self.visit(x)
+
+    def visit_call(self, c):
+        self.visit(c.op)
+        for a in c.args:
+            self.visit(a)
+
+    def visit_var(self, v):
+        pass
+
+    def visit_let(self, l):
+        self.visit(l.var)
+        self.visit(l.value)
+        self.visit(l.body)
+
+    def visit_function(self, f):
+        self.visit(f.body)
+
+    def visit_if(self, i):
+        self.visit(i.cond)
+        self.visit(i.true_branch)
+        self.visit(i.false_branch)
+
+    def visit_global_var(self, gv):
+        pass
+
+    def visit_constructor(self, c):
+        pass
+
+    def visit_op(self, op):
+        pass
+
+    def visit_constant(self, const):
+        pass
+
+    def visit_ref_create(self, r):
+        self.visit(r.value)
+
+    def visit_ref_read(self, r):
+        self.visit(r.ref)
+
+    def visit_ref_write(self, r):
+        self.visit(r.ref)
+        self.visit(r.value)
+
+    def visit_tuple_getitem(self, t):
+        self.visit(t.tuple_value)
+
+    def visit_match(self, m):
+        self.visit(m.data)
+        for c in m.clauses:
+            self.visit(c.rhs)
+
+
 class ExprMutator(ExprFunctor):
     """
     A functional visitor over Expr.
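A quick illustration of why the memo_map change above matters (a minimal sketch, not part of this patch; CountAdds is a hypothetical helper). Visitor methods return None, so the old `if found:` test treated every cached result as a miss; checking membership with `in` makes each distinct node dispatch exactly once, even when `expr = expr + expr` builds a DAG with 2^n paths through only n add nodes:

    from tvm import relay
    from tvm.relay import ExprVisitor

    class CountAdds(ExprVisitor):
        """Hypothetical visitor counting how many distinct call nodes are dispatched."""
        def __init__(self):
            super(CountAdds, self).__init__()
            self.calls = 0

        def visit_call(self, call):
            # ExprFunctor.visit consults memo_map before dispatching here,
            # so this body runs once per distinct node, not once per path.
            self.calls += 1
            super(CountAdds, self).visit_call(call)

    expr = relay.const(1)
    for _ in range(100):
        expr = expr + expr   # 2^100 paths through the DAG, but only 100 add nodes

    counter = CountAdds()
    counter.visit(expr)
    assert counter.calls == 100   # with the old truthiness check, this would never finish
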
diff --git a/tests/python/relay/test_expr_functor.py b/tests/python/relay/test_expr_functor.py index 2a58c282b4c7..ae5ee7bd8bd4 100644 --- a/tests/python/relay/test_expr_functor.py +++ b/tests/python/relay/test_expr_functor.py @@ -16,34 +16,42 @@ # under the License. import tvm from tvm import relay -from tvm.relay import ExprFunctor, ExprMutator +from tvm.relay import ExprFunctor, ExprMutator, ExprVisitor def check_visit(expr): - ef = ExprFunctor() try: + ef = ExprFunctor() ef.visit(expr) assert False except NotImplementedError: pass + ev = ExprVisitor() + ev.visit(expr) + em = ExprMutator() assert em.visit(expr) + def test_constant(): check_visit(relay.const(1.0)) + def test_tuple(): t = relay.Tuple([relay.var('x', shape=())]) check_visit(t) + def test_var(): v = relay.var('x', shape=()) check_visit(v) + def test_global(): v = relay.GlobalVar('f') check_visit(v) + def test_function(): x = relay.var('x', shape=()) y = relay.var('y', shape=()) @@ -61,12 +69,14 @@ def test_function(): ) check_visit(f) + def test_call(): x = relay.var('x', shape=()) y = relay.var('y', shape=()) call = relay.op.add(x, y) check_visit(call) + def test_let(): x = relay.var('x', shape=()) value = relay.const(2.0) @@ -74,30 +84,43 @@ def test_let(): l = relay.Let(x, value, body) check_visit(l) + def test_ite(): cond = relay.var('x', shape=(), dtype='bool') ite = relay.If(cond, cond, cond) check_visit(ite) + def test_get_item(): t = relay.Tuple([relay.var('x', shape=())]) t = relay.TupleGetItem(t, 0) check_visit(t) + def test_ref_create(): r = relay.expr.RefCreate(relay.const(1.0)) check_visit(r) + def test_ref_read(): ref = relay.expr.RefCreate(relay.const(1.0)) r = relay.expr.RefRead(ref) check_visit(r) + def test_ref_write(): ref = relay.expr.RefCreate(relay.const(1.0)) r = relay.expr.RefWrite(ref, relay.const(2.0)) check_visit(r) + +def test_memo(): + expr = relay.const(1) + for _ in range(100): + expr = expr + expr + check_visit(expr) + + if __name__ == "__main__": test_constant() test_tuple() @@ -110,3 +133,4 @@ def test_ref_write(): test_ref_create() test_ref_read() test_ref_write() + test_memo() From bb87f044099ba61ba4782d17dd9127b869936373 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Mon, 8 Apr 2019 22:56:57 -0700 Subject: [PATCH 3/7] add document (#2714) lint lint save save add more case save error lint lint commit do lint save fix lint wrap it back as func lint save remove dead comment fix style fix lint Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame address review feedback pe now handle freevar. as a result preserving function is now trivial. 
test add basic test, implement pretty printing for generic function test lint fix segfault save save do test fix another error address comment commit save address review feedback add test for invalidate, fix error in lookup rename cont to boduy fix error and add regression test fix error, add test case Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame fix lint remove extra line save save --- include/tvm/relay/expr.h | 4 +- include/tvm/relay/expr_functor.h | 1 + include/tvm/relay/pass.h | 30 +- include/tvm/relay/pattern_functor.h | 1 + python/tvm/relay/ir_pass.py | 17 + src/relay/backend/interpreter.cc | 2 +- src/relay/ir/expr_functor.cc | 3 - src/relay/ir/pretty_printer.cc | 70 +- src/relay/ir/type_functor.h | 1 + src/relay/pass/dead_code.cc | 153 ++-- src/relay/pass/partial_eval.cc | 805 ++++++++++++++++++ src/relay/pass/pass_util.h | 33 +- src/relay/pass/to_a_normal_form.cc | 11 +- src/relay/pass/util.cc | 53 +- .../relay/test_pass_dead_code_elimination.py | 16 +- tests/python/relay/test_pass_gradient.py | 2 + tests/python/relay/test_pass_partial_eval.py | 140 +++ ..._form.py => test_pass_to_a_normal_form.py} | 13 + ...m.py => test_pass_to_graph_normal_form.py} | 0 19 files changed, 1243 insertions(+), 112 deletions(-) create mode 100644 src/relay/pass/partial_eval.cc create mode 100644 tests/python/relay/test_pass_partial_eval.py rename tests/python/relay/{test_to_a_normal_form.py => test_pass_to_a_normal_form.py} (95%) rename tests/python/relay/{test_to_graph_normal_form.py => test_pass_to_graph_normal_form.py} (100%) diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 47f5e9debae6..1d2fa5472993 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -570,8 +570,8 @@ inline const TTypeNode* ExprNode::type_as() const { * \return The text representation. */ std::string AsText(const NodeRef& node, - bool show_meta_data = true, - runtime::TypedPackedFunc annotate = nullptr); + bool show_meta_data = true, + runtime::TypedPackedFunc annotate = nullptr); } // namespace relay } // namespace tvm #endif // TVM_RELAY_EXPR_H_ diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 9c29aebe3e7c..3b179f8e5330 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -89,6 +89,7 @@ class ExprFunctor { * \return The result of the call */ virtual R VisitExpr(const Expr& n, Args... args) { + CHECK(n.defined()); static FType vtable = InitVTable(); return vtable(n, this, std::forward(args)...); } diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 2f8a50b9e65b..2db3a061b872 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -64,7 +64,7 @@ #include #include #include - +#include #include #include @@ -344,6 +344,17 @@ TVM_DLL bool WellFormed(const Expr& expr); */ TVM_DLL tvm::Array BoundVars(const Expr& expr); +/*! \brief Get all bound variables from pattern pat. + * + * Bound variables are all variables that got bound by the pat. + * They only have meaning inside that expr, and can only be used in it. + * + * \param pat the Pattern. + * + * \return List of bound vars, in the PostDFS order in the expression. + */ +TVM_DLL tvm::Array BoundVars(const Pattern& pat); + /*! \brief Get free type parameters from expression expr. * * Free variables are variables that are not bound by a @@ -431,12 +442,13 @@ TVM_DLL tvm::Array AllTypeVars(const Type& t, const Module& mod); /*! \brief Remove expressions which does not effect the program result. 
* - * It will remove let bindings which are not referenced, and branches that will - * not be entered. + * It will remove let bindings which are not referenced, + * and inline let bindings that are only used once. * - * For example, this pass should turn `let a = 1 in 2` into `2`, as the value of - * the expression does not depend on a. Another example is `if (true) then 1 - * else 2` will be optimized into 1. + * For example, this pass should turn `let a = 1 in 2` into `2`, + * as the value of the expression does not depend on a. + * + * As another example, `let a = 1 in a` will be optimized into 1. * * \param e the expression to optimize. * @@ -558,6 +570,12 @@ TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod); */ TVM_DLL Expr ToGraphNormalForm(const Expr& e); +/*! \brief Aggressive constant propagation/constant folding/inlining. + * It will do as much computation in compile time as possible. + * It has two benefit: remove runtime overhead, and allow more optimization (typically fusion). + * As a side effect, code size will explode. + */ +Expr PartialEval(const Expr& e); } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/pattern_functor.h b/include/tvm/relay/pattern_functor.h index 5c4020f11f0b..0ced3eaad2b8 100644 --- a/include/tvm/relay/pattern_functor.h +++ b/include/tvm/relay/pattern_functor.h @@ -89,6 +89,7 @@ class PatternFunctor { * \return The result of the call */ virtual R VisitPattern(const Pattern& n, Args... args) { + CHECK(n.defined()); static FType vtable = InitVTable(); return vtable(n, this, std::forward(args)...); } diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index b3d323b2aed6..d2000263479d 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -956,3 +956,20 @@ def pass_debug_print(ast, show_meta_data=True, annotate=None, gnf=True): A text representation of `ast`. """ return _ir_pass.pass_debug_print(ast, show_meta_data, annotate, gnf) + + +def partial_evaluate(expr): + """ + Evaluate the static fragment of the code. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. + + Returns + ------- + expr : tvm.relay.Expr + The output expression. + """ + return _ir_pass.partial_evaluate(expr) diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 1d659d8922d9..735f1830d049 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -556,7 +556,7 @@ class Interpreter : CHECK_NE(cvn->constructor->tag, -1); if (op->constructor->tag == cvn->constructor->tag) { // todo(M.K.): should use ptr equality but it is broken - CHECK(op->patterns.size() == cvn->fields.size()); + CHECK_EQ(op->patterns.size(), cvn->fields.size()); for (size_t i = 0; i < op->patterns.size(); ++i) { if (!VisitPattern(op->patterns[i], cvn->fields[i])) { return false; diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 6df7efe51674..d0cd30adda29 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -43,9 +43,6 @@ Expr ExprMutator::VisitExpr(const Expr& expr) { } Expr ExprMutator::VisitExpr_(const VarNode* op) { - // NOTE: var will only be mutated once - // Thanks to the memo and reused during rewriting if necessary. 
- // It is safe to assume that the if (op->type_annotation.defined()) { auto type = this->VisitType(op->type_annotation); if (!op->type_annotation.same_as(type)) { diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 67cf9371bc27..969f08b32e83 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -245,15 +245,55 @@ class PrettyPrinter : return Doc(unique_prefix); } + Doc Print(Kind k) { + switch (k) { + case kType: + return Doc("Type"); + case kShapeVar: + return Doc("Shape"); + case kBaseType: + return Doc("BaseType"); + case kConstraint: + return Doc("Constraint"); + case kAdtHandle: + return Doc("AdtHandle"); + case kTypeData: + return Doc("TypeData"); + default: + LOG(ERROR) << "Unknown Kind"; + throw; + } + } /*! - * \brief Allocate name to a variable. - * \param var The input variable. - * \return The corresponding name. - */ + * \brief Allocate name to a type variable. + * \param var The input type variable. + * \return The corresponding name. + */ + Doc AllocTypeVar(const TypeVar& var) { + std::string name = var->var->name_hint; + if (name.length() == 0 || !std::isalpha(name[0])) { + name = "t" + name; + } + Doc val = GetUniqueName("%" + name); + if (memo_type_.count(var)) { + val << "-malformed-ir"; + } + memo_type_[var] = val; + if (var->kind != kType) { + val << ": " << Print(var->kind); + } + return val; + } + + /*! + * \brief Allocate name to a variable. + * \param var The input variable. + * \return The corresponding name. + */ Doc AllocVar(const Var& var) { std::string name = var->name_hint(); // always make sure first name is alpha - if (name.length() != 0 && !std::isalpha(name[0])) { + if (name.length() == 0 || !std::isalpha(name[0])) { name = "v" + name; } Doc val = GetUniqueName("%" + name); @@ -387,12 +427,18 @@ class PrettyPrinter : } Doc PrintFunc(const Doc& prefix, const Function& fn) { - // TODO(tqchen, M.K.) support generic function - // Possibly through meta data - CHECK_EQ(fn->type_params.size(), 0U) - << "generic fn not yet supported"; Doc doc; - doc << prefix << "("; + doc << prefix; + if (fn->type_params.size() > 0) { + doc << "<"; + std::vector type_params; + for (const TypeVar& tv : fn->type_params) { + type_params.push_back(AllocTypeVar(tv)); + } + doc << PrintVec(type_params); + doc << ">"; + } + doc << "("; std::vector params; for (Var param : fn->params) { params.push_back(AllocVar(param)); @@ -516,6 +562,10 @@ class PrettyPrinter : return Print(GetRef(node), true); } + Doc VisitType_(const TypeVarNode* node) final { + return AllocTypeVar(GetRef(node)); + } + Doc VisitType_(const TensorTypeNode* node) final { // scalar type if (node->shape.size() == 0) { diff --git a/src/relay/ir/type_functor.h b/src/relay/ir/type_functor.h index acd868c0f9a8..e143fdac824d 100644 --- a/src/relay/ir/type_functor.h +++ b/src/relay/ir/type_functor.h @@ -77,6 +77,7 @@ class TypeFunctor { * \return The result of the call */ virtual R VisitType(const Type& n, Args... 
args) { + CHECK(n.defined()); static FType vtable = InitVTable(); return vtable(n, this, std::forward(args)...); } diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc index a6324560d2d3..06cd9091749b 100644 --- a/src/relay/pass/dead_code.cc +++ b/src/relay/pass/dead_code.cc @@ -35,90 +35,109 @@ namespace tvm { namespace relay { -bool IsBoolLit(const Expr& e, bool b) { - if (const ConstantNode* c = e.as()) { - if (c->is_scalar()) { - auto dt = c->tensor_type()->dtype; - if (dt == Bool()) { - return *reinterpret_cast(c->data->data) == b; - } else if (dt == UInt(8)) { - return *reinterpret_cast(c->data->data) == b; - } else if (dt == UInt(16)) { - return *reinterpret_cast(c->data->data) == b; - } else if (dt == UInt(32)) { - return *reinterpret_cast(c->data->data) == b; - } else if (dt == UInt(64)) { - return *reinterpret_cast(c->data->data) == b; - } else if (dt == Int(8)) { - return *reinterpret_cast(c->data->data) == b; - } else if (dt == Int(16)) { - return *reinterpret_cast(c->data->data) == b; - } else if (dt == Int(32)) { - return *reinterpret_cast(c->data->data) == b; - } else if (dt == Int(64)) { - return *reinterpret_cast(c->data->data) == b; - } - } - } - return false; -} - // calculate the dependency graph from expression -class CalcDep : private ExprMutator { +class CalcDep : private ExprVisitor { public: static Expr Eliminate(const Expr& e) { CalcDep cd; - auto res = cd(e); - GenLet gl(cd.var_map_); - gl(res); - return gl.lets_.Get(res); + cd.Calculate(e); + Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_); + return el(e); } private: - using VarMap = std::unordered_map; - VarMap var_map_; - - Expr VisitExpr_(const IfNode* i) final { - auto cond = VisitExpr(i->cond); - if (IsBoolLit(cond, true)) { - return Eliminate(i->true_branch); - } else if (IsBoolLit(cond, false)) { - return Eliminate(i->false_branch); - } else { - return IfNode::make(cond, Eliminate(i->true_branch), Eliminate(i->false_branch)); + template + using VarMap = std::unordered_map; + using VarSet = std::unordered_set; + VarMap expr_map_; + VarMap use_map_; + VarSet letrec_set_; + bool count_ = true; + VarSet dead_worklist_; + VarSet current_letrec_; + + void LetRec(const std::function& func, const Var& v) { + current_letrec_.insert(v); + func(); + current_letrec_.erase(v); + } + + void VisitExpr_(const LetNode* l) final { + if (count_) { + CHECK_EQ(expr_map_.count(l->var), 0); + CHECK_EQ(use_map_.count(l->var), 0); + expr_map_[l->var] = l->value; + use_map_[l->var] = 0; + dead_worklist_.insert(l->var); + LetRec([&]() { VisitExpr(l->value); }, l->var); } + VisitExpr(l->body); } - Expr VisitExpr_(const LetNode* l) final { - var_map_[l->var] = Eliminate(l->value); - return VisitExpr(l->body); + void VisitExpr(const Expr& e) final { + ExprFunctor::VisitExpr(e); } - Expr VisitExpr_(const FunctionNode* f) final { - return FunctionNode::make(f->params, - Eliminate(f->body), - f->ret_type, - f->type_params); + void VisitExpr_(const VarNode* v) final { + Var var = GetRef(v); + if (expr_map_.count(var) == 0) { + return; + } + if (current_letrec_.count(var) == 0) { + if (count_) { + use_map_[var] += 1; + dead_worklist_.erase(var); + } else { + CHECK_GT(use_map_[var], 0) << var; + use_map_[var] -= 1; + if (use_map_[var] == 0) { + dead_worklist_.insert(var); + } + } + } else { + letrec_set_.insert(var); + } + } + + void Calculate(const Expr& v) { + VisitExpr(v); + count_ = false; + while (!dead_worklist_.empty()) { + Var dead = *(dead_worklist_.begin()); + dead_worklist_.erase(dead); + 
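+      // This var is now provably dead: its binding will be dropped, so walk its
+      // RHS once more with count_ == false to give back the uses it contributed;
+      // any var whose count reaches zero joins the worklist in VisitExpr_.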
CHECK_EQ(use_map_[dead], 0); + if (expr_map_.count(dead) > 0) { + LetRec([&]() { VisitExpr(expr_map_[dead]); }, dead); + } + } } - // generate the let list from dependency graph - class GenLet : private ExprVisitor { + class Eliminator : private ExprMutator { private: - LetList lets_; - VarMap var_map_; - explicit GenLet(const VarMap& var_map) : var_map_(var_map) { } + VarMap expr_map_; + VarMap use_map_; + VarSet letrec_set_; + explicit Eliminator(const VarMap& expr_map, + const VarMap& use_map, + const VarSet& letrec_set) : + expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set) { } friend CalcDep; - void VisitExpr_(const VarNode* vnode) final { - Var v = GetRef(vnode); - auto it = var_map_.find(v); - if (it != var_map_.end()) { - Expr expr = it->second; - var_map_.erase(it); - // erase before visit to handle letrec - VisitExpr(expr); - // visit before push back so the dependency of dependency is before the dependency - lets_.Push(v, expr); + bool HasLet(const Var& v) { + return (use_map_[v] > 1 || (use_map_[v] != 0 && letrec_set_.count(v) != 0)); + } + + Expr VisitExpr_(const VarNode* op) final { + Var v = GetRef(op); + return (expr_map_.count(v) == 0 || HasLet(v)) ? v : VisitExpr(expr_map_[v]); + } + + Expr VisitExpr_(const LetNode* op) final { + Var v = op->var; + if (HasLet(v)) { + return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body)); + } else { + return VisitExpr(op->body); } } }; diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc new file mode 100644 index 000000000000..f6283d380176 --- /dev/null +++ b/src/relay/pass/partial_eval.cc @@ -0,0 +1,805 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * + * \file partial_eval.cc + * + * \brief Perform known computation in compile time. + * + * The partial evaluator try to do computation at compile time, + * so it can generate code that do less work. + * Additionally, it might open more chance for further optimization, + * since the high level, structural part of the code (closure, reference, control flow) + * might get partially evaluated away, and the subsequent optimization (for example, kernel fusion) + * can reason across those structural code as it got removed. + * In the extreme case, partial evaluation can even turn the whole program + * into pure first order computation with no control flow. + * In such a case, we can compile the whole computation onto SIMD Instruction/GPU/FPGA, + * and get huge speedup. + * + * It works by making the following modifications to the standard relay interpreter: + * + * 0: The values become partially static value. + * Since we cannot know the value of every term at compile time, + * Term might get partially evaluated to 'Unknown Value'. 
+ * Every partially static value is, hence,
+ * a static fragment that might not be there (partially static),
+ * and a dynamic fragment that is semantically equivalent to the original term,
+ * so the unknown part will be computed at runtime, using the dynamic fragment.
+ *
+ * 1: The interpreter holds a LetList, which preserves A Normal Form for the generated code.
+ * More specifically, we require that every dynamic fragment is an atom.
+ * This avoids code duplication (which is both inefficient and incorrect), as atoms have
+ * constant size, and it lets us avoid capture-avoiding substitution (as atoms have no binders).
+ *
+ * 2: The map of References to partially static values is reified, as described below.
+ * Instead of a Reference having a mutable field, a Reference only has a unique identifier.
+ * There will be a mutable mapping of id to partially static value, called the store.
+ * This allows us to roll back the store:
+ * when a path may or may not be executed (as in a conditional), we copy the store,
+ * recurse with the copy, and reinstate the original when the call returns
+ * so that the effects of the computation are not preserved.
+ * We do this in if-else, in pattern matching, and in functions,
+ * as, when we see a function, we partially evaluate it with all the arguments dynamic,
+ * to generate an efficient dynamic fragment for that function.
+ *
+ * 3: The generated code reuses bindings (although they are not shadowed),
+ * so we have to deduplicate them.
+ *
+ * 4: In the generated code, multiple VarNodes might have the same Id.
+ * While that is permitted, most passes use NodeHash for Var,
+ * and having multiple VarNodes for the same Id breaks them.
+ * Thus we remap them to a single Id for now.
+ *
+ * It will also generate lots of dead code,
+ * so it is a good idea to feed it through the dead code eliminator after partial evaluation.
+ *
+ * The partial evaluator makes several assumptions, so there is room for improvement:
+ *
+ * 0: The partial evaluator treats global variables as opaque.
+ * Doing PartialEval at the module level will solve this.
+ *
+ * 1: The partial evaluator assumes all functions terminate.
+ * We need a max_expand parameter that shrinks on every compile-time evaluation,
+ * to make sure PE does not loop forever.
+ * Additionally, we might add a termination analysis pass that lifts this requirement
+ * for functions that the analysis finds terminating.
+ *
+ * 2: Every time an unknown effect happens, we clear the whole store.
+ * This is too conservative: if a local reference is created (and does not get passed outside),
+ * an unknown global function call/global reference write cannot modify it.
+ * We can pair PE with escape analysis/alias analysis.
+ *
+ * 3: We assume all unknown code has effects. Doing effect analysis can make the store more precise.
+ *
+ * 4: When doing pattern matching, we can simplify the match even in the dynamic case.
+ * Right now it is all or nothing: either a complete match, or the original dynamic code.
+ * Instead, we can build a match tree, pair it with the data, and evaluate it to a normal form.
+ * We can then reify the result.
+ *
+ * 5: Every time a function is called, its code will get expanded and partially evaluated.
+ * We can do a binding-time analysis to cache the result and avoid re-partial-evaluation.
+ *
+ * These assumptions do not affect the correctness of the algorithm, however.
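+ *
+ * As a small, illustrative example (a sketch of the intended behavior, not
+ * taken from a test): partially evaluating `let x = 1 + 2 in x + y` with `y`
+ * unknown computes the static fragment to 3 and leaves the dynamic fragment
+ * `3 + y` to be computed at runtime.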
+ */ +#include +#include +#include +#include +#include "pass_util.h" +#include "let_list.h" + +namespace tvm { +namespace relay { + +using namespace runtime; + +/*! \brief Hash Var by it's id. + * Different VarNode might has same vid, and they are considered to be the same var in such case. + * Use VarHash to hash Var by id. + */ +struct VarHash { + size_t operator()(const Var& v) const { + return v->vid.hash(); + } +}; + +/*! \brief Compare Var by it's id. + * Different VarNode might has same vid, and they are considered to be the same var in such case. + * Use VarEqual to compare Var by id. + */ +struct VarEqual { + bool operator()(const Var& l, const Var& r) const { + return l->vid.get() == r->vid.get(); + } +}; + +/*! \brief The base container type of Relay values. */ +class StaticNode : public RelayNode { + public: + static constexpr const char* _type_key = "relay.Value"; + TVM_DECLARE_BASE_NODE_INFO(ValueNode, RelayNode); +}; + +class Static : public NodeRef { + public: + Static() {} + explicit Static(NodePtr n) : NodeRef(n) {} + const ValueNode* operator->() const { + return static_cast(node_.get()); + } + + using ContainerType = StaticNode; +}; + +struct PStaticNode : Node { + Static pstatic; // may be null + Expr dynamic; + PStaticNode(const Static& pstatic, const Expr& dynamic) : pstatic(pstatic), dynamic(dynamic) { } + explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { } + TVM_DECLARE_NODE_TYPE_INFO(PStaticNode, Node); +}; + +RELAY_DEFINE_NODE_REF(PStatic, PStaticNode, NodeRef); + +struct STupleNode : StaticNode { + std::vector fields; + explicit STupleNode(const std::vector& fields) : fields(fields) { } + TVM_DECLARE_NODE_TYPE_INFO(STupleNode, StaticNode); +}; + +RELAY_DEFINE_NODE_REF(STuple, STupleNode, Value); + +Static MkSTuple(const std::vector& fields) { + return Static(make_node(fields)); +} + +struct STensorNode : StaticNode { + runtime::NDArray data; + explicit STensorNode(const NDArray& data) : data(data) { } + TVM_DECLARE_NODE_TYPE_INFO(STupleNode, StaticNode); +}; + +RELAY_DEFINE_NODE_REF(STensor, STensorNode, Value); + +Static MkSTensor(const NDArray& data) { + return Static(make_node(data)); +} + +struct SConstructorNode : StaticNode { + Constructor constructor; + std::vector fields; + SConstructorNode(const Constructor& constructor, const std::vector& fields) : + constructor(constructor), fields(fields) { } + TVM_DECLARE_NODE_TYPE_INFO(SConstructorNode, StaticNode); +}; + +RELAY_DEFINE_NODE_REF(SConstructor, SConstructorNode, Value); + +Static MkSConstructor(const Constructor& constructor, const std::vector& fields) { + return Static(make_node(constructor, fields)); +} + +struct SRefNode : StaticNode { + // we will use the address as the guid for hashing + TVM_DECLARE_NODE_TYPE_INFO(SRefNode, StaticNode); +}; + +RELAY_DEFINE_NODE_REF(SRef, SRefNode, Value); + +Static MkSRef() { + return Static(make_node()); +} + +using Func = std::function&, + const Attrs&, + const Array&, + LetList*)>; + +struct SFuncNode : StaticNode { + Func func; + explicit SFuncNode(const Func& func) : func(func) { } + TVM_DECLARE_NODE_TYPE_INFO(SFuncNode, StaticNode); +}; + +RELAY_DEFINE_NODE_REF(SFunc, SFuncNode, Value); + +Static MkSFunc(const Func& func) { + return Static(make_node(func)); +} + +/*! + * \brief A stack frame in the Relay interpreter. + * + * Contains a mapping from relay::Var to relay::Value. + */ +struct Frame { + /*! \brief The set of local variables and arguments for the frame. 
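+   *  Keyed by the Var's vid through VarHash/VarEqual, so distinct VarNodes
+   *  sharing an id resolve to the same binding.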
*/ + std::unordered_map locals; + Frame() = default; +}; + +class Environment { + public: + Environment() : env_({Frame()}) { } + Environment(const Environment&) = delete; + + template + T Extend(const std::function& body) { + FrameContext fc(this); + return body(); + } + + void Insert(const Var& v, const PStatic& ps) { + CHECK(ps.defined()); + env_.back().locals[v] = ps; + } + + PStatic Lookup(const Var& v) { + auto rit = env_.rbegin(); + while (rit != env_.rend()) { + if (rit->locals.find(v) != rit->locals.end()) { + return rit->locals.find(v)->second; + } + ++rit; + } + LOG(FATAL) << "Unknown Variable: " << v; + throw; + } + + private: + std::list env_; + + struct FrameContext { + Environment* env_; + explicit FrameContext(Environment* env) : env_(env) { + env_->env_.push_back(Frame()); + } + ~FrameContext() { + env_->env_.pop_back(); + } + }; +}; + +/*! + * \brief As our store require rollback, we implement it as a frame. + * every time we need to copy the store, a new frame is insert. + * every time we roll back, a frame is popped. + */ +struct StoreFrame { + std::unordered_map store; + /*! \brief on unknown effect, history_valid is set to true to signal above frame is outdated */ + bool history_valid = true; + explicit StoreFrame(const std::unordered_map& store) : store(store) { } + StoreFrame() = default; +}; + +class Store { + public: + Store() : store_({StoreFrame()}) { } + Store(const Store&) = delete; + + template + T Extend(const std::function& body) { + StoreFrameContext sfc(this); + return body(); + } + + void Insert(const SRefNode* r, const PStatic& ps) { + store_.back().store[r] = ps; + } + + // return null if not found + PStatic Lookup(const SRefNode* r) { + auto rit = store_.rbegin(); + while (rit != store_.rend()) { + if (!rit->history_valid) { + return PStatic(); + } + if (rit->store.find(r) != rit->store.end()) { + return rit->store.find(r)->second; + } + ++rit; + } + return PStatic(); + } + + void Invalidate() { + store_.back().history_valid = false; + } + + private: + std::list store_; + + struct StoreFrameContext { + Store* store_; + explicit StoreFrameContext(Store* store) : store_(store) { + store_->store_.push_back(StoreFrame()); + } + ~StoreFrameContext() { + store_->store_.pop_back(); + } + }; +}; + +PStatic HasStatic(const Static& stat, const Expr& dynamic) { + return PStatic(make_node(stat, dynamic)); +} + +PStatic NoStatic(const Expr& dynamic) { + return PStatic(make_node(dynamic)); +} + +enum struct MatchStatus { + Match, NoMatch, Unknown +}; + +bool StatefulOp(const Expr& e) { + static auto op_stateful = Op::GetAttr("TOpIsStateful"); + struct StatefulOpVisitor : ExprVisitor { + bool stateful = false; + void VisitExpr_(const OpNode* op) { + stateful = stateful || op_stateful.get(GetRef(op), false); + } + }; + StatefulOpVisitor sov; + sov(e); + return sov.stateful; +} + +using FInterpreter = runtime::TypedPackedFunc; + +DLContext CPUContext() { + DLContext ctx; + ctx.device_type = kDLCPU; + ctx.device_id = 0; + return ctx; +} + +FInterpreter CPUInterpreter() { + Target target = Target::create("llvm"); + // use a fresh build context + // in case we are already in a build context. 
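+  // Without the fresh scope, options from an enclosing build (this pass may
+  // run inside one) would apply to the code compiled for constant evaluation.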
+ BuildConfigContext fresh_build_ctx(build_config()); + + return CreateInterpreter(Module(nullptr), CPUContext(), target); +} + +class PartialEvaluator : public ExprFunctor, + public PatternFunctor { + public: + PartialEvaluator(const tvm::Array& free_vars) { + for (const Var& v : free_vars) { + env_.Insert(v, NoStatic(v)); + } + } + + PStatic VisitExpr_(const ConstantNode* op, LetList* ll) final { + return HasStatic(MkSTensor(op->data.CopyTo(context_)), ll->Push(GetRef(op))); + } + + PStatic VisitExpr_(const TupleNode* op, LetList* ll) final { + std::vector value; + tvm::Array expr; + for (const Expr& e : op->fields) { + PStatic ps = VisitExpr(e, ll); + value.push_back(ps); + expr.push_back(ps->dynamic); + } + return HasStatic(MkSTuple(value), ll->Push(TupleNode::make(expr))); + } + + PStatic VisitExpr_(const TupleGetItemNode* op, LetList* ll) final { + PStatic ps = VisitExpr(op->tuple, ll); + if (ps->pstatic.defined()) { + return Downcast(ps->pstatic)->fields[op->index]; + } else { + return NoStatic(ll->Push(TupleGetItemNode::make(ps->dynamic, op->index))); + } + } + + PStatic VisitExpr_(const VarNode* op, LetList* ll) final { + return env_.Lookup(GetRef(op)); + } + + PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final { + return NoStatic(GetRef(op)); + } + + PStatic VisitExpr_(const LetNode* op, LetList* ll) final { + env_.Insert(op->var, VisitExpr(op->value, ll)); + return VisitExpr(op->body, ll); + } + + PStatic VisitExpr_(const IfNode* op, LetList* ll) final { + PStatic c = VisitExpr(op->cond, ll); + if (c->pstatic.defined()) { + NDArray cpu_array = Downcast(c->pstatic)->data.CopyTo(CPUContext()); + CHECK_EQ(TVMType2Type(cpu_array->dtype), Bool()); + if (reinterpret_cast(cpu_array->data)[0]) { + return VisitExpr(op->true_branch, ll); + } else { + return VisitExpr(op->false_branch, ll); + } + } else { + Expr t = store_.Extend([&]() { + return LetList::With([&](LetList* ll) { + return VisitExpr(op->true_branch, ll)->dynamic; + }); + }); + Expr f = store_.Extend([&]() { + return LetList::With([&](LetList* ll) { + return VisitExpr(op->false_branch, ll)->dynamic; + }); + }); + store_.Invalidate(); + return NoStatic(ll->Push(IfNode::make(c->dynamic, t, f))); + } + } + + PStatic VisitExpr_(const RefCreateNode* op, LetList* ll) final { + PStatic ps = VisitExpr(op->value, ll); + Static r = MkSRef(); + store_.Insert(r.as(), ps); + return HasStatic(r, ll->Push(RefCreateNode::make(ps->dynamic))); + } + + PStatic VisitExpr_(const RefWriteNode* op, LetList* ll) final { + PStatic r = VisitExpr(op->ref, ll); + PStatic v = VisitExpr(op->value, ll); + if (r->pstatic.defined()) { + store_.Insert(r->pstatic.as(), v); + } else { + store_.Invalidate(); + } + return HasStatic(MkSTuple({}), ll->Push(RefWriteNode::make(r->dynamic, v->dynamic))); + } + + PStatic VisitExpr_(const RefReadNode* op, LetList* ll) final { + PStatic r = VisitExpr(op->ref, ll); + if (r->pstatic.defined()) { + PStatic ret = store_.Lookup(r->pstatic.as()); + if (ret) { + return ret; + } + } + return NoStatic(ll->Push(RefReadNode::make(r->dynamic))); + } + + PStatic VisitExpr_(const CallNode* op, LetList* ll) final { + PStatic f = VisitExpr(op->op, ll); + std::vector x; + tvm::Array x_dyn; + for (const Expr& e : op->args) { + PStatic ps = VisitExpr(e, ll); + x.push_back(ps); + x_dyn.push_back(ps->dynamic); + } + if (f->pstatic.defined()) { + return Downcast(f->pstatic)->func(x, op->attrs, op->type_args, ll); + } else { + store_.Invalidate(); + return NoStatic(ll->Push(CallNode::make(f->dynamic, x_dyn, op->attrs, 
op->type_args))); + } + } + + PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final { + Function func = GetRef(op); + if (func->IsPrimitive()) { + return HasStatic(MkSFunc(ConstEvaluateFunc(func, ll)), func); + } + std::vector > free_vars; + for (const auto& v : FreeVars(GetRef(op))) { + free_vars.push_back(std::pair(v, env_.Lookup(v))); + } + Func f = [=](const std::vector& pv, + const Attrs& attrs, + const tvm::Array& type_args, + LetList* ll) { + return env_.Extend([&]() { + CHECK_EQ(pv.size(), func->params.size()); + for (size_t i = 0; i < pv.size(); ++i) { + env_.Insert(func->params[i], pv[i]); + } + for (const auto& p : free_vars) { + env_.Insert(p.first, p.second); + } + tvm::Map subst; + for (size_t i = 0; i < type_args.size(); ++i) { + subst.Set(func->type_params[i], type_args[i]); + } + for (size_t i = type_args.size(); i < func->type_params.size(); ++i) { + subst.Set(func->type_params[i], Type()); + } + return VisitExpr(TypeSubst(func->body, subst), ll); + }); + }; + Expr dyn = store_.Extend([&]() { + store_.Invalidate(); + return FunctionNode::make(func->params, LetList::With([&](LetList* ll) { + std::vector pv; + for (const auto& v : func->params) { + pv.push_back(NoStatic(v)); + } + tvm::Array type_args; + for (const auto& tp : func->type_params) { + type_args.push_back(tp); + } + return f(pv, Attrs(), type_args, ll)->dynamic; + }), func->ret_type, func->type_params, func->attrs); + }); + return HasStatic(MkSFunc(f), ll->Push(dyn)); + } + + Expr Reflect(const PStatic& st) { + if (const STensorNode* op = st->pstatic.as()) { + return ConstantNode::make(op->data); + } else if (const STupleNode* op = st->pstatic.as()) { + tvm::Array fields; + for (const PStatic& field : op->fields) { + fields.push_back(Reflect(field)); + } + return TupleNode::make(fields); + } else { + LOG(FATAL) << "Unknown case"; + throw; + } + } + + PStatic Reify(const Value& v, LetList* ll) const { + if (const TensorValueNode* op = v.as()) { + return HasStatic(MkSTensor(op->data), ll->Push(ConstantNode::make(op->data))); + } else if (const TupleValueNode* op = v.as()) { + std::vector fields; + tvm::Array fields_dyn; + for (const Value& field : op->fields) { + PStatic ps = Reify(field, ll); + fields.push_back(ps); + fields_dyn.push_back(ps->dynamic); + } + return HasStatic(MkSTuple(fields), ll->Push(TupleNode::make(fields_dyn))); + } else { + LOG(FATAL) << "Unknown case"; + throw; + } + } + + // Constant evaluate a expression. 
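+  // The steps below mirror the normal compile flow: infer types, fuse ops at
+  // opt level 0, infer again, run the result on the CPU interpreter, and
+  // reify the produced Value back into a partially static value.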
+ PStatic ConstEvaluate(const Expr& expr, LetList* ll) { + Expr infered = InferType(expr, Module(nullptr)); + Expr fused = FuseOps(infered, 0); + Expr fused_infered = InferType(fused, Module(nullptr)); + return Reify(executor_(fused_infered), ll); + } + + Func ConstEvaluateFunc(const Expr& expr, LetList* ll) { + return [=](const std::vector& pv, + const Attrs& attrs, + const tvm::Array& type_args, + LetList* ll) { + tvm::Array ns_args; + for (const PStatic& ps : pv) { + ns_args.push_back(ps->dynamic); + } + PStatic ns = NoStatic(CallNode::make(expr, ns_args, attrs, type_args)); + if (StatefulOp(expr)) { + return ns; + } + tvm::Array args; + for (const PStatic& ps : pv) { + if (ps->pstatic.defined()) { + args.push_back(Reflect(ps)); + } else { + return ns; + } + } + return ConstEvaluate(CallNode::make(expr, args, attrs, type_args), ll); + }; + } + + PStatic VisitExpr_(const OpNode* op, LetList* ll) final { + return HasStatic(MkSFunc(ConstEvaluateFunc(GetRef(op), ll)), GetRef(op)); + } + + PStatic VisitExpr_(const ConstructorNode* op, LetList* ll) final { + Constructor c = GetRef(op); + Func f = [=](const std::vector& pv, + const Attrs& attrs, + const tvm::Array& type_args, + LetList* ll) { + tvm::Array dyn; + for (const PStatic& ps : pv) { + dyn.push_back(ps->dynamic); + } + return HasStatic(MkSConstructor(c, pv), ll->Push(CallNode::make(c, dyn))); + }; + return HasStatic(MkSFunc(f), GetRef(op)); + } + + PStatic VisitExpr_(const MatchNode* op, LetList* ll) final { + PStatic ps = VisitExpr(op->data, ll); + return env_.Extend([&]() { + for (const Clause& c : op->clauses) { + switch (VisitPattern(c->lhs, ps)) { + case MatchStatus::Match: + return VisitExpr(c->rhs, ll); + case MatchStatus::NoMatch: + continue; + case MatchStatus::Unknown: + tvm::Array clauses; + for (const Clause& c : op->clauses) { + Expr expr = store_.Extend([&]() { + return LetList::With([&](LetList* ll) { + for (const Var& v : BoundVars(c->lhs)) { + env_.Insert(v, NoStatic(v)); + } + return VisitExpr(c->rhs, ll)->dynamic; + }); + }); + clauses.push_back(ClauseNode::make(c->lhs, expr)); + } + store_.Invalidate(); + return NoStatic(ll->Push(MatchNode::make(ps->dynamic, clauses))); + } + } + LOG(FATAL) << "No case Match"; + throw; + }); + } + + MatchStatus VisitPattern_(const PatternWildcardNode* op, const PStatic& ps) final { + return MatchStatus::Match; + } + + MatchStatus VisitPattern_(const PatternVarNode* op, const PStatic& ps) final { + env_.Insert(op->var, ps); + return MatchStatus::Match; + } + + MatchStatus VisitPattern_(const PatternConstructorNode* op, const PStatic& ps) final { + if (ps->pstatic.defined()) { + SConstructor scn = Downcast(ps->pstatic); + CHECK_NE(op->constructor->tag, -1); + CHECK_NE(scn->constructor->tag, -1); + if (op->constructor->tag == scn->constructor->tag) { + // todo(M.K.): should use ptr equality but it is broken + CHECK_EQ(op->patterns.size(), scn->fields.size()); + MatchStatus current_match_status = MatchStatus::Match; + for (size_t i = 0; i < op->patterns.size(); ++i) { + MatchStatus ms = VisitPattern(op->patterns[i], scn->fields[i]); + switch (ms) { + case MatchStatus::Match: + continue; + case MatchStatus::NoMatch: + return MatchStatus::NoMatch; + case MatchStatus::Unknown: + current_match_status = MatchStatus::Unknown; + } + } + return current_match_status; + } + return MatchStatus::NoMatch; + } else { + return MatchStatus::Unknown; + } + } + + private: + Environment env_; + Store store_; + DLContext context_ = CPUContext(); + FInterpreter executor_ = CPUInterpreter(); +}; + +Var 
DeDupVar(const Var& v) { + return VarNode::make(v->name_hint(), v->type_annotation); +} + +TypeVar DeDupTypeVar(const TypeVar& tv) { + return TypeVarNode::make(tv->var->name_hint, tv->kind); +} + +/*! \brief Use a fresh Id for every Var to make the result well-formed. */ +Expr DeDup(const Expr& e) { + class DeDupMutator : public ExprMutator, public PatternMutator { + public: + Var Fresh(const Var& v) { + Var ret = DeDupVar(v); + rename_[v] = ret; + return ret; + } + + Expr VisitExpr(const Expr& e) final { + return ExprMutator::VisitExpr(e); + } + + Expr VisitExpr_(const VarNode* op) final { + Var v = GetRef(op); + return rename_.count(v) != 0 ? rename_.at(v) : v; + } + + Expr VisitExpr_(const LetNode* op) final { + return LetNode::make(Fresh(op->var), VisitExpr(op->value), VisitExpr(op->body)); + } + + Expr VisitExpr_(const FunctionNode* op) final { + tvm::Array params; + for (const Var& param : op->params) { + params.push_back(Fresh(param)); + } + return FunctionNode::make(params, + VisitExpr(op->body), + op->ret_type, + op->type_params, + op->attrs); + } + + Pattern VisitPattern(const Pattern& p) final { + return PatternMutator::VisitPattern(p); + } + + Var VisitVar(const Var& v) final { + return Fresh(v); + } + + private: + std::unordered_map rename_; + }; + return DeDupMutator().VisitExpr(e); +} + +/*! \brief Remap multiple Var sharing the same Id into the same Var. */ +Expr Remap(const Expr& e) { + class RemapMutator : public ExprMutator, public PatternMutator { + Expr VisitExpr_(const VarNode* op) final { + Var v = GetRef(op); + if (remap_.count(v) == 0) { + remap_.insert({v, v}); + } + return remap_.at(v); + } + + Var VisitVar(const Var& v) final { + return Downcast(VisitExpr(v)); + } + + private: + std::unordered_map remap_; + }; + return RemapMutator().VisitExpr(e); +} + +Expr PartialEval(const Expr& e) { + return TransformF([&](const Expr& e) { + return LetList::With([&](LetList* ll) { + PartialEvaluator pe(FreeVars(e)); + return Remap(DeDup(pe.VisitExpr(e, ll)->dynamic)); + }); + }, e); +} + +TVM_REGISTER_API("relay._ir_pass.partial_evaluate") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = PartialEval(args[0]); + }); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/pass_util.h b/src/relay/pass/pass_util.h index df3294151ff0..38d8b0bd9040 100644 --- a/src/relay/pass/pass_util.h +++ b/src/relay/pass/pass_util.h @@ -42,7 +42,6 @@ namespace relay { std::unordered_map GetExprRefCount(const Expr& body); - /*! * \brief Check if expr is positive constant. * \param expr The expression to be checked. @@ -50,7 +49,6 @@ GetExprRefCount(const Expr& body); */ bool IsAllPositiveConstant(const Expr& expr); - /*! * \brief Substitute var with subst. * \param type The type to be substituted. @@ -60,6 +58,15 @@ bool IsAllPositiveConstant(const Expr& expr); */ Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst); +/*! + * \brief Substitute var with subst. + * \param expr The expr to be substituted. + * \param tvar The type variable to be substituted. + * \param subst The target of substitution. + * \return The substituted result. + */ +Expr TypeSubst(const Expr& expr, const TypeVar& tvar, const Type& subst); + /*! * \brief Substitute type vars in type. * \param type The type to be substituted. @@ -68,6 +75,28 @@ Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst); */ Type TypeSubst(const Type& type, const tvm::Map& subst_map); +/*! + * \brief Substitute type vars in type. + * \param expr The expr to be substituted. 
+ * \param subst_map The map of substitution. + * \return The substituted result. + */ +Expr TypeSubst(const Expr& expr, const tvm::Map& subst_map); + +/*! + * \brief Make arbitrary transformation preserve the out most function. + * \param func The transformation. + * \param e The expression + * \return the transformed expression. If e is a function the return is also a function. + */ +inline Expr TransformF(const std::function& func, const Expr& e) { + if (const FunctionNode* f = e.as()) { + return FunctionNode::make(f->params, func(f->body), f->ret_type, f->type_params, f->attrs); + } else { + return func(e); + } +} + } // namespace relay } // namespace tvm #endif // TVM_RELAY_PASS_PASS_UTIL_H_ diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc index 60fd6dbbaf1a..bac6fd28faf5 100644 --- a/src/relay/pass/to_a_normal_form.cc +++ b/src/relay/pass/to_a_normal_form.cc @@ -28,6 +28,7 @@ #include #include "let_list.h" #include "../../common/arena.h" +#include "pass_util.h" namespace tvm { namespace relay { @@ -481,15 +482,7 @@ Expr ToANormalFormAux(const Expr& e, const Module& m, std::set* gv) { } Expr ToANormalForm(const Expr& e, const Module& m, std::set* gv) { - if (const auto* f = e.as()) { - return FunctionNode::make(f->params, - ToANormalFormAux(f->body, m, gv), - f->ret_type, - f->type_params, - f->attrs); - } else { - return ToANormalFormAux(e, m, gv); - } + return TransformF([&](const Expr& e) { return ToANormalFormAux(e, m, gv); }, e); } Expr ToANormalForm(const Expr& e, const Module& m) { diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index 3080b5d72e30..fa655a785338 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -27,6 +27,7 @@ #include #include #include +#include "pass_util.h" #include "../ir/type_functor.h" namespace tvm { @@ -171,8 +172,7 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor { return ret; } - Array Bound(const Expr& expr) { - this->VisitExpr(expr); + Array Collect() { Array ret; for (const auto& v : bound_vars_.data) { ret.push_back(v); @@ -180,6 +180,16 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor { return ret; } + Array Bound(const Expr& expr) { + this->VisitExpr(expr); + return Collect(); + } + + Array Bound(const Pattern& pat) { + this->VisitPattern(pat); + return Collect(); + } + Array All(const Expr& expr) { this->VisitExpr(expr); Array ret; @@ -256,6 +266,10 @@ tvm::Array BoundVars(const Expr& expr) { return VarVisitor().Bound(expr); } +tvm::Array BoundVars(const Pattern& pat) { + return VarVisitor().Bound(pat); +} + tvm::Array AllVars(const Expr& expr) { return VarVisitor().All(expr); } @@ -267,7 +281,12 @@ TVM_REGISTER_API("relay._ir_pass.free_vars") TVM_REGISTER_API("relay._ir_pass.bound_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = BoundVars(args[0]); + NodeRef x = args[0]; + if (x.as_derived()) { + *ret = BoundVars(Downcast(x)); + } else { + *ret = BoundVars(Downcast(x)); + } }); TVM_REGISTER_API("relay._ir_pass.all_vars") @@ -388,5 +407,33 @@ bool IsAllPositiveConstant(const Expr& expr) { } } +Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst) { + return TypeSubst(type, tvm::Map({{tvar, subst}})); +} + +Expr TypeSubst(const Expr& expr, const TypeVar& tvar, const Type& subst) { + return TypeSubst(expr, tvm::Map({{tvar, subst}})); +} + +Type TypeSubst(const Type& type, const tvm::Map& subst_map) { + return Bind(type, subst_map); +} + +Expr TypeSubst(const Expr& expr, const tvm::Map& subst_map) { + class 
diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc
index 3080b5d72e30..fa655a785338 100644
--- a/src/relay/pass/util.cc
+++ b/src/relay/pass/util.cc
@@ -27,6 +27,7 @@
 #include
 #include
 #include
+#include "pass_util.h"
 #include "../ir/type_functor.h"
 
 namespace tvm {
 namespace relay {
@@ -171,8 +172,7 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor {
     return ret;
   }
 
-  Array<Var> Bound(const Expr& expr) {
-    this->VisitExpr(expr);
+  Array<Var> Collect() {
     Array<Var> ret;
     for (const auto& v : bound_vars_.data) {
       ret.push_back(v);
@@ -180,6 +180,16 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor {
     return ret;
   }
 
+  Array<Var> Bound(const Expr& expr) {
+    this->VisitExpr(expr);
+    return Collect();
+  }
+
+  Array<Var> Bound(const Pattern& pat) {
+    this->VisitPattern(pat);
+    return Collect();
+  }
+
   Array<Var> All(const Expr& expr) {
     this->VisitExpr(expr);
     Array<Var> ret;
@@ -256,6 +266,10 @@ tvm::Array<Var> BoundVars(const Expr& expr) {
   return VarVisitor().Bound(expr);
 }
 
+tvm::Array<Var> BoundVars(const Pattern& pat) {
+  return VarVisitor().Bound(pat);
+}
+
 tvm::Array<Var> AllVars(const Expr& expr) {
   return VarVisitor().All(expr);
 }
@@ -267,7 +281,12 @@ TVM_REGISTER_API("relay._ir_pass.free_vars")
 
 TVM_REGISTER_API("relay._ir_pass.bound_vars")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = BoundVars(args[0]);
+    NodeRef x = args[0];
+    if (x.as_derived<ExprNode>()) {
+      *ret = BoundVars(Downcast<Expr>(x));
+    } else {
+      *ret = BoundVars(Downcast<Pattern>(x));
+    }
   });
 
 TVM_REGISTER_API("relay._ir_pass.all_vars")
@@ -388,5 +407,33 @@ bool IsAllPositiveConstant(const Expr& expr) {
   }
 }
 
+Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst) {
+  return TypeSubst(type, tvm::Map<TypeVar, Type>({{tvar, subst}}));
+}
+
+Expr TypeSubst(const Expr& expr, const TypeVar& tvar, const Type& subst) {
+  return TypeSubst(expr, tvm::Map<TypeVar, Type>({{tvar, subst}}));
+}
+
+Type TypeSubst(const Type& type, const tvm::Map<TypeVar, Type>& subst_map) {
+  return Bind(type, subst_map);
+}
+
+Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
+  class TypeSubstMutator : public ExprMutator, public PatternMutator {
+   public:
+    explicit TypeSubstMutator(const tvm::Map<TypeVar, Type>& subst_map) : subst_map_(subst_map) { }
+    Type VisitType(const Type& t) final {
+      return TypeSubst(t, subst_map_);
+    }
+    Var VisitVar(const Var& v) final {
+      return Downcast<Var>(VisitExpr(v));
+    }
+   private:
+    const tvm::Map<TypeVar, Type>& subst_map_;
+  };
+  return TypeSubstMutator(subst_map).VisitExpr(expr);
+}
+
 } // namespace relay
 } // namespace tvm
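With the dispatch above, the bound_vars entry point accepts either an expression or a match pattern. Assuming the thin Python wrapper in tvm.relay.ir_pass forwards its argument unchanged, a pattern query looks like the following sketch:

    from tvm import relay
    from tvm.relay.ir_pass import bound_vars

    y = relay.Var("y")
    pat = relay.PatternVar(y)   # a match pattern that binds y
    vs = bound_vars(pat)        # dispatches to the new Pattern overload
    assert len(vs) == 1 and vs[0].name_hint == "y"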
diff --git a/tests/python/relay/test_pass_dead_code_elimination.py b/tests/python/relay/test_pass_dead_code_elimination.py
index 9014a47f2ef3..963d490eaf50 100644
--- a/tests/python/relay/test_pass_dead_code_elimination.py
+++ b/tests/python/relay/test_pass_dead_code_elimination.py
@@ -48,8 +48,13 @@ def test_let():
 
 
 def test_used_let():
+    orig = relay.Let(e.c, e.one, e.c + e.c)
+    assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.one, e.c + e.c))
+
+
+def test_inline():
     orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
-    assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.d, e.c))
+    assert alpha_equal(dead_code_elimination(orig), e.d)
 
 
 def test_chain_unused_let():
@@ -87,13 +92,6 @@ def test_op_let():
     assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three), e.two)), add(e.three, e.two))
 
 
-def test_if():
-    cond = relay.const(True)
-    orig = relay.If(cond, e.a, e.b)
-    y = dead_code_elimination(orig)
-    assert alpha_equal(y, e.a)
-
-
 def test_tuple_get_item():
     t = relay.Var('t')
     g = relay.TupleGetItem(t, 0)
@@ -102,9 +100,9 @@
 
 
 if __name__ == "__main__":
-    test_if()
     test_let()
     test_used_let()
+    test_inline()
     test_chain_unused_let()
     test_recursion()
     test_op_let()
diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py
index d6489a8f3377..f5968a41f028 100644
--- a/tests/python/relay/test_pass_gradient.py
+++ b/tests/python/relay/test_pass_gradient.py
@@ -22,9 +22,11 @@
 import numpy as np
 
 
+
 def rand(dtype='float32', *shape):
     return tvm.nd.array(np.random.rand(*shape).astype(dtype))
 
+
 def test_id():
     shape = (10, 10)
     dtype = 'float32'
diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py
new file mode 100644
index 000000000000..a00cebd244b3
--- /dev/null
+++ b/tests/python/relay/test_pass_partial_eval.py
@@ -0,0 +1,140 @@
+import numpy as np
+import tvm
+from tvm import relay
+from tvm.relay.ir_pass import partial_evaluate, dead_code_elimination
+from tvm.relay.ir_pass import gradient, alpha_equal, infer_type
+from tvm.relay import op, create_executor
+from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
+from tvm.relay.prelude import Prelude
+
+
+def check_eval(expr, expected_result, mod=None, rtol=1e-07):
+    ctx = tvm.context("llvm", 0)
+    intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
+
+    result = intrp.evaluate(expr)
+    np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
+
+
+def dcpe(expr):
+    return dead_code_elimination(partial_evaluate(expr))
+
+
+def test_tuple():
+    t = relay.TypeVar("t")
+    x = relay.Var("x", t)
+    body = relay.TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1)
+    f = relay.Function([x], body, None, [t])
+    assert alpha_equal(dcpe(f), relay.Function([x], x, None, [t]))
+
+
+def test_const_inline():
+    d = relay.Var("d")
+    double = relay.Function([d], d + d)
+    orig = double(relay.const(4.0))
+    assert alpha_equal(dcpe(orig), relay.const(8.0))
+
+
+def test_ref():
+    d = relay.Var("d")
+    r = relay.Var("r")
+    x = relay.Var("x")
+    body = relay.RefRead(r)
+    body = relay.Let(x, relay.RefWrite(r, relay.RefRead(r) * relay.RefRead(r)), body)
+    body = relay.Let(r, relay.RefCreate(d), body)
+    square = relay.Function([d], body)
+    assert alpha_equal(dcpe(square), relay.Function([d], d * d))
+
+
+def test_ad():
+    shape = (10, 10)
+    dtype = "float32"
+    t = relay.TensorType(shape, dtype)
+    d = relay.Var("d", t)
+    f = relay.Function([d], d * d)
+    g = dcpe(gradient(f))
+    m = d * d
+    o = relay.op.ones_like(m)
+    grad = relay.op.zeros_like(d) + relay.op.collapse_sum_like(o * d, d) + relay.op.collapse_sum_like(o * d, d)
+    expected = relay.Function([d], relay.Tuple([m, relay.Tuple([grad])]))
+    assert alpha_equal(g, expected)
+
+
+def test_if_ref():
+    shape = ()
+    dtype = "bool"
+    t = relay.TensorType(shape, dtype)
+    d = relay.Var("d", t)
+    r = relay.Var("r")
+    update = relay.Function([], relay.RefWrite(r, relay.RefRead(r) + relay.RefRead(r)))
+    u = relay.Var("u")
+    body = relay.If(d, u(), u())
+    eff = relay.Var("eff")
+    body = relay.Let(eff, body, relay.RefRead(r))
+    f = relay.Function([d], relay.Let(r, relay.RefCreate(relay.const(1)), relay.Let(u, update, body)))
+    f = infer_type(f)
+    pe_f = infer_type(partial_evaluate(f))
+    ex = create_executor()
+    f_res = ex.evaluate(f)(relay.const(True))
+    pe_f_res = ex.evaluate(pe_f)(relay.const(True))
+    np.testing.assert_allclose(f_res.asnumpy(), 2 * np.ones_like(f_res.asnumpy()))
+    np.testing.assert_allclose(pe_f_res.asnumpy(), 2 * np.ones_like(pe_f_res.asnumpy()))
+
+
+def test_function_invalidate():
+    shape = ()
+    dtype = "bool"
+    t = relay.TensorType(shape, dtype)
+    d = relay.Var("d", t)
+    r = relay.Var("r")
+    fetch = relay.Function([], relay.RefRead(r))
+    fet = relay.Var("fetch")
+    fet_obscured = relay.Var("fetch_obscured")
+    u = relay.Var("u")
+    body = relay.If(d, fet_obscured(), fet_obscured())
+    body = relay.Let(u, relay.RefWrite(r, relay.const(1)), body)
+    body = relay.Let(fet_obscured, relay.If(d, fet, fet), body)
+    body = relay.Let(fet, fetch, body)
+    body = relay.Let(r, relay.RefCreate(relay.const(0)), body)
+    f = relay.Function([d], body)
+    f = infer_type(f)
+    pe_f = infer_type(partial_evaluate(f))
+    ex = create_executor()
+    f_res = ex.evaluate(f)(relay.const(True))
+    pe_f_res = ex.evaluate(pe_f)(relay.const(True))
+    np.testing.assert_allclose(f_res.asnumpy(), np.ones_like(f_res.asnumpy()))
+    np.testing.assert_allclose(pe_f_res.asnumpy(), np.ones_like(pe_f_res.asnumpy()))
+
+
+def test_head_cons():
+    mod = relay.Module()
+    p = Prelude(mod)
+    def hd_impl():
+        a = relay.TypeVar("a")
+        x = relay.Var("x", p.l(a))
+        y = relay.Var("y")
+        z = relay.Var("z")
+        cons_case = relay.Clause(relay.PatternConstructor(p.cons,
+                                                          [relay.PatternVar(y),
+                                                           relay.PatternVar(z)]),
+                                 y)
+        return relay.Function([x], relay.Match(x, [cons_case]), a, [a])
+    t = relay.TypeVar("t")
+    x = relay.Var("x", t)
+    hd = relay.Var("hd")
+    body = relay.Let(hd, hd_impl(), hd(p.cons(x, p.nil())))
+    f = relay.Function([x], body, None, [t])
+    f = infer_type(f, mod=mod)
+    res = dcpe(f)
+    assert alpha_equal(res, relay.Function([x], x, t, [t]))
+
+
+if __name__ == '__main__':
+    test_tuple()
+    test_const_inline()
+    test_ref()
+    test_ad()
+    test_if_ref()
+    test_function_invalidate()
+    test_head_cons()
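test_if_ref and test_function_invalidate above share one pattern: run the original and the partially evaluated function on the same input and compare the results. A small helper in that spirit (check_preserves_semantics is a sketch, not part of the patch):

    import numpy as np
    from tvm.relay import create_executor
    from tvm.relay.ir_pass import partial_evaluate, infer_type

    def check_preserves_semantics(f, *args):
        # Evaluate the original and the partially evaluated function on
        # the same arguments and require numerically identical results.
        ex = create_executor()
        f = infer_type(f)
        pe_f = infer_type(partial_evaluate(f))
        res = ex.evaluate(f)(*args).asnumpy()
        pe_res = ex.evaluate(pe_f)(*args).asnumpy()
        np.testing.assert_allclose(res, pe_res)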
diff --git a/tests/python/relay/test_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py
similarity index 95%
rename from tests/python/relay/test_to_a_normal_form.py
rename to tests/python/relay/test_pass_to_a_normal_form.py
index 9e1e8728f131..2e95dbe55121 100644
--- a/tests/python/relay/test_to_a_normal_form.py
+++ b/tests/python/relay/test_pass_to_a_normal_form.py
@@ -154,6 +154,7 @@ def test_add():
     assert count(intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2
     assert "let" in mod[add].astext()
 
+
 def test_let():
     x = relay.Var("x")
     y = relay.Var("y")
@@ -163,6 +164,17 @@ def test_let():
     check_eval(body, 8)
     check_eval(to_a_normal_form(body), 8)
 
+
+def test_function():
+    x = relay.Var("x")
+    f = relay.Function([x], x + x)
+    d = relay.const(4.0, 'float32')
+    anf_f = to_a_normal_form(f)
+    assert isinstance(anf_f, relay.Function)
+    check_eval(f(d), 8)
+    check_eval(anf_f(d), 8)
+
+
 if __name__ == '__main__':
     test_explicit_bound()
     test_order()
@@ -171,3 +183,4 @@ def test_let():
     test_ref()
     test_add()
     test_let()
+    test_function()
diff --git a/tests/python/relay/test_to_graph_normal_form.py b/tests/python/relay/test_pass_to_graph_normal_form.py
similarity index 100%
rename from tests/python/relay/test_to_graph_normal_form.py
rename to tests/python/relay/test_pass_to_graph_normal_form.py

From 06347782a36e07094443b22a80b74e302f599d0b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?=
Date: Tue, 9 Apr 2019 11:49:16 -0700
Subject: [PATCH 4/7] Update let_list.h (#2987)

---
 src/relay/pass/let_list.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/relay/pass/let_list.h b/src/relay/pass/let_list.h
index 293f88979886..bd36a15c843c 100644
--- a/src/relay/pass/let_list.h
+++ b/src/relay/pass/let_list.h
@@ -81,7 +81,7 @@ class LetList {
    * \return a Var that holds the inserted expr.
    */
   Var Push(Expr expr) {
-    return Push(IncompleteTypeNode::make(Kind::kType), expr);
+    return Push(Type(), expr);
   }
 
   /*!

From 9cf622b4a0c2e6bee1dd879f1cdfbb5b4a69f12f Mon Sep 17 00:00:00 2001
From: Philip Hyunsu Cho
Date: Tue, 9 Apr 2019 21:59:16 -0700
Subject: [PATCH 5/7] Expose backtrace symbols in Debug mode (#3001)

---
 CMakeLists.txt | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9317ee3aad52..884c44e4c6df 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -84,7 +84,8 @@ else(MSVC)
   include(CheckCXXCompilerFlag)
   check_cxx_compiler_flag("-std=c++11" SUPPORT_CXX11)
   if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
-    add_compile_options(-O0 -Wall -fPIC -fvisibility=hidden -std=c++11)
+    set(CMAKE_C_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_C_FLAGS} -rdynamic")
+    set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fPIC -std=c++11 ${CMAKE_CXX_FLAGS} -rdynamic")
   else()
     set(CMAKE_C_FLAGS "-O2 -Wall -fPIC -fvisibility=hidden ${CMAKE_C_FLAGS}")
     set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -fvisibility=hidden -std=c++11 ${CMAKE_CXX_FLAGS}")

From d712e0d73de1e64c811cc5f33a3b25075e33bd12 Mon Sep 17 00:00:00 2001
From: eqy
Date: Tue, 9 Apr 2019 22:11:21 -0700
Subject: [PATCH 6/7] add output format to ndk build func (#2999)

---
 python/tvm/contrib/ndk.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/python/tvm/contrib/ndk.py b/python/tvm/contrib/ndk.py
index 931fa03a7308..e1703ce03f8e 100644
--- a/python/tvm/contrib/ndk.py
+++ b/python/tvm/contrib/ndk.py
@@ -63,3 +63,6 @@ def create_shared(output,
         msg = "Compilation error:\n"
         msg += py_str(out)
         raise RuntimeError(msg)
+
+# assign output format
+create_shared.output_format = "so"
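Attaching output_format mirrors the same attribute on tvm.contrib.cc.create_shared, so code that receives a build function can pick a matching file suffix rather than hard-coding one. A sketch of the usual call site (lib stands for an already-built runtime module, which this patch does not define):

    from tvm.contrib import ndk

    # Cross-compile with the Android NDK toolchain; the ".so" suffix
    # agrees with ndk.create_shared.output_format.
    lib.export_library("deploy.so", ndk.create_shared)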
From 57f47a17f266e4123be49b84b5caf6a143d2544a Mon Sep 17 00:00:00 2001
From: Yizhi Liu
Date: Tue, 9 Apr 2019 22:12:50 -0700
Subject: [PATCH 7/7] fix java checkstyle version (#2998)

---
 jvm/core/pom.xml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/jvm/core/pom.xml b/jvm/core/pom.xml
index 88e272814ba1..849c86b10549 100644
--- a/jvm/core/pom.xml
+++ b/jvm/core/pom.xml
@@ -77,7 +77,7 @@ under the License.
         <groupId>com.puppycrawl.tools</groupId>
         <artifactId>checkstyle</artifactId>
-        <version>[8.18,)</version>
+        <version>8.18</version>