diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h index 8eda7dd824ca..c65bb41282cf 100644 --- a/include/tvm/relay/analysis.h +++ b/include/tvm/relay/analysis.h @@ -66,6 +66,15 @@ TVM_DLL Kind KindCheck(const Type& t, const IRModule& mod); */ TVM_DLL bool ConstantCheck(const Expr& e); +/*! + * \brief Check whether an expression is in the basic block normal form. + * + * \param e the expression. + * + * \return whether the expression is in the basic block normal form. + */ +TVM_DLL bool BasicBlockNormalFormCheck(const Expr& e); + /*! * \brief Check that each Var is only bound once. * diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index d995301c1688..cf14febb02c1 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -116,6 +116,21 @@ TVM_DLL Pass FuseOps(int fuse_opt_level = -1); */ TVM_DLL Pass RewriteAnnotatedOps(int fallback_device); +/*! + * \brief Turn an expression to Basic Block Normal Form. + * + * We define a block as a group of expressions implied by the scope structure. + * + * Each graph node can only belong to a single block. + * + * For any value that is being used in multiple blocks, it has to be referred + * by a Var which is defined in a block, whose scope is the least common ancestor + * of blocks this value is used. + * + * \return The pass. + */ +TVM_DLL Pass ToBasicBlockNormalForm(); + /*! * \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF). * diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index 632af460ce96..165e39a09c1b 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -106,6 +106,21 @@ def check_constant(expr): """ return _ffi_api.check_constant(expr) +def check_basic_block_normal_form(expr): + """Check whether an expression is in the basic block form + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression + + Returns + ------- + result : bool + Whether the expression is in the basic block form. + """ + return _ffi_api.check_basic_block_normal_form(expr) + def free_vars(expr): """Get free Vars from expression expr in Post DFS order. diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 7db068785ba6..3abc3822f0ef 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -488,6 +488,21 @@ def ToANormalForm(): """ return _ffi_api.ToANormalForm() +def ToBasicBlockNormalForm(): + """Turn an expression to Basic Block Normal Form. + We define a block as a group of expressions implied by the scope structure. + Each graph node can only belong to a single block. + For any value that is being used in multiple blocks, it has to be referred + by a Var which is defined in a block, whose scope is the least common ancestor + of blocks this value is used. + + Returns + ------- + ret: tvm.transform.Pass + The registered pass that transforms an expression into Basic Block Normal Form. + """ + return _ffi_api.ToBasicBlockNormalForm() + def ToCPS(expr, mod=None): """ diff --git a/src/relay/analysis/dependency_graph.cc b/src/relay/analysis/dependency_graph.cc index 5db833866a3e..de61800d8c52 100644 --- a/src/relay/analysis/dependency_graph.cc +++ b/src/relay/analysis/dependency_graph.cc @@ -137,6 +137,9 @@ class DependencyGraph::Creator : private ExprFunctor { DependencyGraph::Node* n = graph_.expr_node[GetRef(f)]; DependencyGraph::Node* b = NewNode(true); Depend(n, b); + for (const auto& p : f->params) { + Depend(b, p); + } Depend(b, f->body); graph_.post_dfs_order.push_back(b); } @@ -145,6 +148,7 @@ class DependencyGraph::Creator : private ExprFunctor { DependencyGraph::Node* n = graph_.expr_node[GetRef(l)]; DependencyGraph::Node* b = NewNode(true); Depend(n, b); + Depend(b, l->var); Depend(b, l->value); Depend(b, l->body); graph_.post_dfs_order.push_back(b); diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 4d84c480a5c6..533619ec8a19 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -253,6 +253,7 @@ class RelayBuildModule : public runtime::ModuleNode { Array pass_seqs; Array entry_functions{"main"}; pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); + pass_seqs.push_back(transform::ToBasicBlockNormalForm()); // Run all dialect legalization passes. pass_seqs.push_back(relay::qnn::transform::Legalize()); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index b811911b4053..a98f1ef9d0ab 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -927,6 +927,7 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe Array pass_seqs; Array entry_functions{"main"}; pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); + pass_seqs.push_back(transform::ToBasicBlockNormalForm()); // Run all dialect legalization passes. pass_seqs.push_back(relay::qnn::transform::Legalize()); diff --git a/src/relay/transforms/let_list.h b/src/relay/transforms/let_list.h index c0e0b3a23864..c925dc0922a4 100644 --- a/src/relay/transforms/let_list.h +++ b/src/relay/transforms/let_list.h @@ -107,6 +107,12 @@ class LetList { return ret; } + /*! \brief get the number of let bindings in the let list. + * + * \return the let list size. + */ + size_t size() const { return lets_.size(); } + /*! \brief generate an LetList and wrap the result automatically. * * \param f a function that generate the unwrapped Expr. diff --git a/src/relay/transforms/pass_util.h b/src/relay/transforms/pass_util.h index 5f5876212b62..50d0fbb5f17b 100644 --- a/src/relay/transforms/pass_util.h +++ b/src/relay/transforms/pass_util.h @@ -31,6 +31,11 @@ #include #include +#include +#include + +#include "../analysis/dependency_graph.h" +#include "let_list.h" namespace tvm { namespace relay { @@ -184,6 +189,89 @@ struct TreeBranchNode : TreeNode { ~TreeBranchNode() {} }; +struct ScopeNode; +using Scope = std::shared_ptr; +using NodeScopeMap = std::unordered_map; +using ExprSet = std::unordered_set; + +/* Invariant: when parent is null level is 0 + * Invariant: when parent is not null level is 1 + parent->level + */ +struct ScopeNode { + // the level of the scope + size_t level; + // the parent scope + Scope parent; + // the corresponding let list which holds all let bindings in the scope + std::shared_ptr let_list = std::make_shared(); + explicit ScopeNode(const Scope& parent) : level(1 + parent->level), parent(parent) {} + ScopeNode() : level(0) {} +}; + +/*! \brief Calculate the scope of nodes in the dependency graph by least common ancestor. + * + * \param dg the input dependency graph + * \param expr_scope the output node -> scope mapping for all nodes. + * \param lifted_exprs the output set of expressions whose scope is lifted due to dependency + */ +std::pair CalcScope(const DependencyGraph& dg); + +/*! \brief find the least common ancestor of lhs scope and rhs scope. + */ +Scope LCA(Scope lhs, Scope rhs); + +/* Special care is needed to handle local recursion. + * Fill additionally take a (possibly null) Var argument, + * If it is not null, Fill is required to bind the transformed result to that var. + */ +class Fill : ExprFunctor { + public: + static Expr ToANormalForm(const Expr& e, const DependencyGraph& dg, NodeScopeMap* node_scope); + + // For basic block normal form, bind expressions only if the original expression's + // scope should be lifted + static Expr ToBasicBlockNormalForm(const Expr& e, const DependencyGraph& dg, + NodeScopeMap* node_scope, ExprSet* lifted); + + private: + const DependencyGraph& dg_; + NodeScopeMap* node_scope_ = nullptr; + std::unordered_map memo; + // a set of Expressions to include for let bindings. If set to nullptr + // all Exprs will be pushed to the let list. + ExprSet* include_set_ = nullptr; + + Fill(const DependencyGraph& dg, NodeScopeMap* node_scope, ExprSet* include_set) + : dg_(dg), node_scope_(node_scope), include_set_(include_set) {} + + Scope GetScope(const Expr& e); + Scope GetSubScope(const Expr& e, size_t i); + + Expr VisitExpr(const Expr& e, const Var& v) final; + Expr VisitExpr(const Expr& e); + + Expr Atomic(const Expr& e, const Var& v); + // Bind expression `now` to var `v` if the original expression is in the include set, or if + // v is already defined (e.g. coming from a Let expression). Otherwise return `now` directly. + Expr Compound(const Expr& orig, const Expr& now, const Var& v); + + Expr VisitExpr_(const CallNode* c, const Var& v) final; + Expr VisitExpr_(const TupleNode* t, const Var& v) final; + Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final; + Expr VisitExpr_(const RefCreateNode* r, const Var& v) final; + Expr VisitExpr_(const RefReadNode* r, const Var& v) final; + Expr VisitExpr_(const RefWriteNode* r, const Var& v) final; + Expr VisitExpr_(const IfNode* i, const Var& v) final; + Expr VisitExpr_(const FunctionNode* f, const Var& v) final; + Expr VisitExpr_(const LetNode* l, const Var& v) final; + Expr VisitExpr_(const ConstantNode* c, const Var& v) final; + Expr VisitExpr_(const VarNode* vn, const Var& v) final; + Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final; + Expr VisitExpr_(const OpNode* op, const Var& v) final; + Expr VisitExpr_(const ConstructorNode* c, const Var& v) final; + Expr VisitExpr_(const MatchNode* m, const Var& v) final; +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_TRANSFORMS_PASS_UTIL_H_ diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index 8d1024217a1e..06e0d56e1919 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -36,23 +36,6 @@ namespace tvm { namespace relay { -struct ScopeNode; -using Scope = std::shared_ptr; - -/* Invariant: when parent is null level is 0 - * - * Invariant: when parent is not null level is 1 + parent->level - */ -struct ScopeNode { - size_t level; - Scope parent; - std::shared_ptr ll = std::make_shared(); - explicit ScopeNode(const Scope& parent) : level(1 + parent->level), parent(parent) {} - ScopeNode() : level(0) {} -}; - -Scope ChildScope(const Scope& s) { return std::make_shared(s); } - Scope LCA(Scope lhs, Scope rhs) { while (lhs != rhs) { if (lhs->level > rhs->level) { @@ -67,10 +50,16 @@ Scope LCA(Scope lhs, Scope rhs) { return lhs; } -std::unordered_map CalcScope(const DependencyGraph& dg) { - std::unordered_map expr_scope; +std::pair CalcScope(const DependencyGraph& dg) { + NodeScopeMap expr_scope; + ExprSet lifted_exprs; + std::unordered_map node_to_expr; + for (auto expr_node : dg.expr_node) { + node_to_expr[expr_node.second] = expr_node.first; + } bool global_scope_used = false; Scope global_scope = std::make_shared(); + for (auto it = dg.post_dfs_order.rbegin(); it != dg.post_dfs_order.rend(); ++it) { DependencyGraph::Node* n = *it; auto iit = n->parents.head; @@ -81,171 +70,187 @@ std::unordered_map CalcScope(const DependencyGrap global_scope_used = true; } else { s = expr_scope.at(iit->value); + const auto original_s = s; iit = iit->next; for (; iit != nullptr; iit = iit->next) { s = LCA(s, expr_scope.at(iit->value)); } + if (s != original_s && node_to_expr.find(n) != node_to_expr.end()) { + // filter out exprs whose scope do not matter + Expr expr = node_to_expr[n]; + if (!expr.as()) { + lifted_exprs.insert(expr); + } + } + } + if (n->new_scope) { + auto child_scope = std::make_shared(s); + expr_scope.insert({n, child_scope}); + } else { + expr_scope.insert({n, s}); } - expr_scope.insert({n, n->new_scope ? ChildScope(s) : s}); } CHECK(global_scope_used); - return expr_scope; + return std::make_pair(expr_scope, lifted_exprs); } -/* Special care is needed to handle local recursion. - * Fill additionally take a (possibly null) Var argument, - * If it is not null, Fill is required to bind the transformed result to that var. - */ -class Fill : ExprFunctor { - public: - static Expr ToANormalForm(const Expr& e, const DependencyGraph& dg, - std::unordered_map* node_scope) { - Fill fi(dg, node_scope); - return fi.GetScope(e)->ll->Get(fi.VisitExpr(e)); - } - - private: - const DependencyGraph& dg_; - std::unordered_map* node_scope_; - std::unordered_map memo; +Expr Fill::ToANormalForm(const Expr& e, const DependencyGraph& dg, NodeScopeMap* node_scope) { + Fill fi(dg, node_scope, nullptr); + return fi.GetScope(e)->let_list->Get(fi.VisitExpr(e)); +} - Fill(const DependencyGraph& dg, std::unordered_map* node_scope) - : dg_(dg), node_scope_(node_scope) {} +// For basic block normal form, bind expressions only if the original expression's scope +// should be lifted +Expr Fill::ToBasicBlockNormalForm(const Expr& e, const DependencyGraph& dg, + NodeScopeMap* node_scope, ExprSet* lifted) { + Fill fi(dg, node_scope, lifted); + auto var = fi.VisitExpr(e); + return fi.GetScope(e)->let_list->Get(var); +} - Scope GetScope(const Expr& e) { return node_scope_->at(dg_.expr_node.at(e)); } +Scope Fill::GetScope(const Expr& e) { return node_scope_->at(dg_.expr_node.at(e)); } - Scope GetSubScope(const Expr& e, size_t i) { - DependencyGraph::Node* n = dg_.expr_node.at(e); - auto h = n->children.head; - while (i != 0) { - CHECK(h); - --i; - h = h->next; - } +Scope Fill::GetSubScope(const Expr& e, size_t i) { + DependencyGraph::Node* n = dg_.expr_node.at(e); + auto h = n->children.head; + while (i != 0) { CHECK(h); - return node_scope_->at(h->value); + --i; + h = h->next; } + CHECK(h); + return node_scope_->at(h->value); +} - Expr VisitExpr(const Expr& e, const Var& v) final { - if (memo.count(e) == 0) { - memo.insert({e, ExprFunctor::VisitExpr(e, v)}); - } else if (v.defined()) { - GetScope(e)->ll->Push(v, memo.at(e)); - } - auto ret = memo.at(e); - CHECK(IsAtomic(ret)); - return ret; +Expr Fill::VisitExpr(const Expr& e, const Var& v) { + if (memo.count(e) == 0) { + memo.insert({e, ExprFunctor::VisitExpr(e, v)}); + } else if (v.defined()) { + GetScope(e)->let_list->Push(v, memo.at(e)); } + auto ret = memo.at(e); + // if no include_set is specified, every expression should be atomic. + if (include_set_ == nullptr) CHECK(IsAtomic(ret)); + return ret; +} - Expr VisitExpr(const Expr& e) { return this->VisitExpr(e, Var()); } +Expr Fill::VisitExpr(const Expr& e) { return this->VisitExpr(e, Var()); } - Expr Atomic(const Expr& e, const Var& v) { return v.defined() ? GetScope(e)->ll->Push(v, e) : e; } +Expr Fill::Atomic(const Expr& e, const Var& v) { + return v.defined() ? GetScope(e)->let_list->Push(v, e) : e; +} - Expr Compound(const Expr& orig, const Expr& now, const Var& v) { - Var var = v.defined() ? v : Var(String("x"), Type()); - return GetScope(orig)->ll->Push(var, now); +// Bind expression `now` to var `v` if the original expression is in the include set, or if +// v is already defined (e.g. coming from a Let expression). Otherwise return `now` directly +Expr Fill::Compound(const Expr& orig, const Expr& now, const Var& v) { + Var var = v.defined() ? v : Var(String("x"), Type()); + bool not_included = include_set_ && include_set_->find(orig) == include_set_->end(); + if (!v.defined() && not_included) { + return now; + } else { + return GetScope(orig)->let_list->Push(var, now); } +} - Expr VisitExpr_(const CallNode* c, const Var& v) final { - Expr e = GetRef(c); - std::vector args; - for (const auto& a : c->args) { - args.push_back(VisitExpr(a)); - } - return Compound(e, Call(VisitExpr(c->op), args, c->attrs, c->type_args), v); +Expr Fill::VisitExpr_(const CallNode* c, const Var& v) { + Expr e = GetRef(c); + std::vector args; + for (const auto& a : c->args) { + args.push_back(VisitExpr(a)); } + return Compound(e, Call(VisitExpr(c->op), args, c->attrs, c->type_args), v); +} - Expr VisitExpr_(const TupleNode* t, const Var& v) final { - Expr e = GetRef(t); - std::vector fields; - for (const auto& a : t->fields) { - fields.push_back(VisitExpr(a)); - } - return Compound(e, Tuple(fields), v); +Expr Fill::VisitExpr_(const TupleNode* t, const Var& v) { + Expr e = GetRef(t); + std::vector fields; + for (const auto& a : t->fields) { + fields.push_back(VisitExpr(a)); } + return Compound(e, Tuple(fields), v); +} - Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final { - Expr e = GetRef(t); - return Compound(e, TupleGetItem(VisitExpr(t->tuple), t->index), v); - } +Expr Fill::VisitExpr_(const TupleGetItemNode* t, const Var& v) { + Expr e = GetRef(t); + return Compound(e, TupleGetItem(VisitExpr(t->tuple), t->index), v); +} - Expr VisitExpr_(const RefCreateNode* r, const Var& v) final { - Expr e = GetRef(r); - return Compound(e, RefCreate(VisitExpr(r->value)), v); - } +Expr Fill::VisitExpr_(const RefCreateNode* r, const Var& v) { + Expr e = GetRef(r); + return Compound(e, RefCreate(VisitExpr(r->value)), v); +} - Expr VisitExpr_(const RefReadNode* r, const Var& v) final { - Expr e = GetRef(r); - return Compound(e, RefRead(VisitExpr(r->ref)), v); - } +Expr Fill::VisitExpr_(const RefReadNode* r, const Var& v) { + Expr e = GetRef(r); + return Compound(e, RefRead(VisitExpr(r->ref)), v); +} - Expr VisitExpr_(const RefWriteNode* r, const Var& v) final { - Expr e = GetRef(r); - return Compound(e, RefWrite(VisitExpr(r->ref), VisitExpr(r->value)), v); - } +Expr Fill::VisitExpr_(const RefWriteNode* r, const Var& v) { + Expr e = GetRef(r); + return Compound(e, RefWrite(VisitExpr(r->ref), VisitExpr(r->value)), v); +} - Expr VisitExpr_(const IfNode* i, const Var& v) final { - Expr e = GetRef(i); - Expr ret = If(VisitExpr(i->cond), GetSubScope(e, 1)->ll->Get(VisitExpr(i->true_branch)), - GetSubScope(e, 2)->ll->Get(VisitExpr(i->false_branch))); - return Compound(e, ret, v); - } +Expr Fill::VisitExpr_(const IfNode* i, const Var& v) { + Expr e = GetRef(i); + Expr ret = If(VisitExpr(i->cond), GetSubScope(e, 1)->let_list->Get(VisitExpr(i->true_branch)), + GetSubScope(e, 2)->let_list->Get(VisitExpr(i->false_branch))); + return Compound(e, ret, v); +} - Expr VisitExpr_(const FunctionNode* f, const Var& v) final { - Expr e = GetRef(f); - Expr ret; - if (f->HasNonzeroAttr(attr::kPrimitive)) { - ret = e; - } else { - ret = Function(f->params, GetSubScope(e, 0)->ll->Get(VisitExpr(f->body)), f->ret_type, - f->type_params, f->attrs); - } - return Compound(e, ret, v); +Expr Fill::VisitExpr_(const FunctionNode* f, const Var& v) { + Expr e = GetRef(f); + Expr ret; + if (f->HasNonzeroAttr(attr::kPrimitive)) { + ret = e; + } else { + ret = Function(f->params, GetSubScope(e, 0)->let_list->Get(VisitExpr(f->body)), f->ret_type, + f->type_params, f->attrs); } + return Compound(e, ret, v); +} - Expr VisitExpr_(const LetNode* l, const Var& v) final { - Expr e = GetRef(l); - VisitExpr(l->value, l->var); - Expr ret = GetSubScope(e, 0)->ll->Get(VisitExpr(l->body)); - return Compound(e, ret, v); - } +Expr Fill::VisitExpr_(const LetNode* l, const Var& v) { + Expr e = GetRef(l); + VisitExpr(l->value, l->var); + Expr ret = GetSubScope(e, 0)->let_list->Get(VisitExpr(l->body)); + return Compound(e, ret, v); +} - Expr VisitExpr_(const ConstantNode* c, const Var& v) final { - Expr e = GetRef(c); - return Compound(e, e, v); - } +Expr Fill::VisitExpr_(const ConstantNode* c, const Var& v) { + Expr e = GetRef(c); + return Compound(e, e, v); +} - Expr VisitExpr_(const VarNode* vn, const Var& v) final { - Expr e = GetRef(vn); - return Atomic(e, v); - } +Expr Fill::VisitExpr_(const VarNode* vn, const Var& v) { + Expr e = GetRef(vn); + return Atomic(e, v); +} - Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final { - GlobalVar gv = GetRef(gvn); - return Atomic(gv, v); - } +Expr Fill::VisitExpr_(const GlobalVarNode* gvn, const Var& v) { + GlobalVar gv = GetRef(gvn); + return Atomic(gv, v); +} - Expr VisitExpr_(const OpNode* op, const Var& v) final { - Expr e = GetRef(op); - return Atomic(e, v); - } +Expr Fill::VisitExpr_(const OpNode* op, const Var& v) { + Expr e = GetRef(op); + return Atomic(e, v); +} - Expr VisitExpr_(const ConstructorNode* c, const Var& v) final { - Expr e = GetRef(c); - return Atomic(e, v); - } +Expr Fill::VisitExpr_(const ConstructorNode* c, const Var& v) { + Expr e = GetRef(c); + return Atomic(e, v); +} - Expr VisitExpr_(const MatchNode* m, const Var& v) final { - Expr e = GetRef(m); - Expr data = VisitExpr(m->data); - std::vector clauses; - for (const Clause& c : m->clauses) { - clauses.push_back( - Clause(c->lhs, GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs)))); - } - return Compound(e, Match(data, clauses, m->complete), v); +Expr Fill::VisitExpr_(const MatchNode* m, const Var& v) { + Expr e = GetRef(m); + Expr data = VisitExpr(m->data); + std::vector clauses; + for (const Clause& c : m->clauses) { + clauses.push_back( + Clause(c->lhs, GetSubScope(e, 1 + clauses.size())->let_list->Get(VisitExpr(c->rhs)))); } -}; + return Compound(e, Match(data, clauses, m->complete), v); +} Expr ToANormalFormAux(const Expr& e) { /* When you lift a lambda, what is inside is also being lift. @@ -269,8 +274,8 @@ Expr ToANormalFormAux(const Expr& e) { * Every scope additionally contain a LetList which collect all value of that scope. * We do an additional pass to fill all the LetList and we are done. */ - std::unordered_map node_scope = CalcScope(dg); - return Fill::ToANormalForm(e, dg, &node_scope); + std::pair scopes = CalcScope(dg); + return Fill::ToANormalForm(e, dg, &scopes.first); } IRModule ToANormalForm(const IRModule& m) { diff --git a/src/relay/transforms/to_basic_block_normal_form.cc b/src/relay/transforms/to_basic_block_normal_form.cc new file mode 100644 index 000000000000..5fc01e151760 --- /dev/null +++ b/src/relay/transforms/to_basic_block_normal_form.cc @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file to_basic_block_normal_form.cc + * + * \brief Turn an expression to the basic normal form. + */ +#include +#include +#include +#include + +#include "../../support/arena.h" +#include "../analysis/dependency_graph.h" +#include "let_list.h" +#include "pass_util.h" + +namespace tvm { +namespace relay { + +Expr ToBasicBlockNormalFormAux(const Expr& e) { + // calculate all the dependency between nodes. + support::Arena arena; + DependencyGraph dg = DependencyGraph::Create(&arena, e); + /* The scope of the whole expr is global. + * The scope of any subexpr, is the lowest common ancestor of all incoming edge. + * We also record the set of expressions whose scope is lifted. + */ + std::pair scopes = CalcScope(dg); + return Fill::ToBasicBlockNormalForm(e, dg, &scopes.first, &scopes.second); +} + +IRModule ToBasicBlockNormalForm(const IRModule& mod) { + DLOG(INFO) << "ToBBlock:" << std::endl << mod; + + tvm::Map updates; + auto funcs = mod->functions; + for (const auto& it : funcs) { + CHECK_EQ(FreeVars(it.second).size(), 0) << "Expected no free variables"; + if (const auto* n = it.second.as()) { + if (n->GetAttr(attr::kCompiler).defined()) continue; + } + Expr ret = TransformF([&](const Expr& e) { return ToBasicBlockNormalFormAux(e); }, it.second); + updates.Set(it.first, Downcast(ret)); + } + + for (auto pair : updates) { + mod->Add(pair.first, pair.second, true); + } + + DLOG(INFO) << "ToBBlock: transformed" << std::endl << mod; + + return mod; +} + +bool BasicBlockNormalFormCheck(const Expr& e) { + // calculate all the dependency between nodes. + support::Arena arena; + DependencyGraph dg = DependencyGraph::Create(&arena, e); + std::pair scopes = CalcScope(dg); + for (auto expr : scopes.second) { + LOG(FATAL) << "The expression below violates the basic block normal form in that " + << "its scope should be lifted:\n" + << expr; + } + return scopes.second.size() == 0; +} + +TVM_REGISTER_GLOBAL("relay.analysis.check_basic_block_normal_form") + .set_body_typed(BasicBlockNormalFormCheck); + +namespace transform { + +Pass ToBasicBlockNormalForm() { + runtime::TypedPackedFunc pass_func = + [=](IRModule m, PassContext pc) { return relay::ToBasicBlockNormalForm(m); }; + return CreateModulePass(pass_func, 1, "ToBasicBlockNormalForm", {}); +} + +TVM_REGISTER_GLOBAL("relay._transform.ToBasicBlockNormalForm") + .set_body_typed(ToBasicBlockNormalForm); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_analysis_basic_block_normal_form.py b/tests/python/relay/test_analysis_basic_block_normal_form.py new file mode 100644 index 000000000000..dfd7dd1f4118 --- /dev/null +++ b/tests/python/relay/test_analysis_basic_block_normal_form.py @@ -0,0 +1,206 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import tvm +from tvm import relay +from tvm.relay.analysis import check_basic_block_normal_form + +def test_one_block(): + x = relay.var('x') + y = relay.add(x, x) + z = relay.add(x, y) + check_basic_block_normal_form(z) + +def test_let(): + x = relay.var('x') + y = relay.var('y') + body = relay.Let(y, x, y) + check_basic_block_normal_form(body) + +@pytest.mark.xfail(raises=tvm.error.TVMError) +def test_invalid_if(): + cond = relay.var('cond', dtype='bool', shape=()) + shared = relay.var('shared') + true_branch = shared + false_branch = relay.add(shared, shared) + body = relay.If(cond, true_branch, false_branch) + """ + The program below violates basic block normal form, as the scope of %shared + is ambiguous and should not be in that of true branch. + + free_var %cond: bool + if (%cond) { + free_var %shared + %shared + } else { + add(%shared, %shared) + } + """ + check_basic_block_normal_form(body) + +def test_valid_if(): + cond = relay.var('cond', dtype='bool', shape=()) + shared = relay.var('shared') + true_branch = shared + false_branch = relay.add(shared, shared) + body = relay.If(cond, true_branch, false_branch) + shared_bound = relay.var('shared_bound', shape=(1,), dtype='float32') + body = relay.Let(shared, shared_bound, body) + """ + The program below uses let binding to control the scope of %shared, which + follows the basic block normal form. + + free_var %shared_bound: Tensor[(1), float32] + let %shared = %shared_bound; + free_var %cond: bool + if (%cond) { + %shared + } else { + add(%shared, %shared) + } + """ + check_basic_block_normal_form(body) + +@pytest.mark.xfail(raises=tvm.error.TVMError) +def test_invalid_if2(): + """ + fn (%x: float32) { + %0 = equal(%x, 2f); + if (%0) { + %1 = add(%x, 1f); + multiply(%1, 2f) + } else { + multiply(%1, 1f) + } + } + """ + x = relay.var('x', shape=(), dtype='float32') + one = relay.const(1, dtype='float32') + two = relay.const(2, dtype='float32') + v1 = relay.add(x, one) + v2 = relay.equal(x, two) + true_branch = relay.multiply(v1, two) + false_branch = relay.multiply(v1, one) + body = relay.If(v2, true_branch, false_branch) + func = relay.Function([x], body) + check_basic_block_normal_form(func) + +def test_valid_if2(): + """ + fn (%x: float32) { + let %v1 = add(%x, 1f); + %0 = equal(%x, 2f); + if (%0) { + multiply(%v1, 2f) + } else { + multiply(%v1, 1f) + } + } + """ + x = relay.var('x', shape=(), dtype='float32') + one = relay.const(1, dtype='float32') + two = relay.const(2, dtype='float32') + v1 = relay.var('v1') + v2 = relay.equal(x, two) + true_branch = relay.multiply(v1, two) + false_branch = relay.multiply(v1, one) + body = relay.If(v2, true_branch, false_branch) + body = relay.Let(v1, relay.add(x, one), body) + func = relay.Function([x], body) + check_basic_block_normal_form(func) + +@pytest.mark.xfail(raises=tvm.error.TVMError) +def test_func(): + x = relay.var('x', shape=(1,), dtype='float32')#, a) + y = relay.var('y', shape=(1,), dtype='float32')#, a) + z = relay.var('z', shape=(1,), dtype='float32')#, a) + x2 = relay.add(x, x) + func_a = relay.Function([y], relay.add(x2, y)) #, a, [a]) + func_b = relay.Function([z], relay.add(x2, z)) #, a, [a]) + body = relay.Tuple([func_a, func_b]) + body = relay.Function([x], body) + """ + fn (%x: Tensor[(1), float32]) { + %1 = fn (%y: Tensor[(1), float32]) { + %0 = add(%x, %x); + add(%0, %y) + }; + %2 = fn (%z: Tensor[(1), float32]) { + add(%0, %z) + }; + (%1, %2) + } + """ + check_basic_block_normal_form(body) + +@pytest.mark.xfail(raises=tvm.error.TVMError) +def test_higher_order_return(): + x = relay.var('x', shape=(1,), dtype='float32')#, a) + y = relay.var('y', shape=(1,), dtype='float32')#, a) + z = relay.var('z', shape=(1,), dtype='float32')#, a) + x2 = relay.add(x, x) + func_a = relay.Function([y], relay.add(x2, y)) #, a, [a]) + func_b = relay.Function([z], relay.add(x2, z)) #, a, [a]) + body = relay.Tuple([func_a, func_b]) + body = relay.Function([x], body) + """ + fn (%x: Tensor[(1), float32]) { + %1 = fn (%y: Tensor[(1), float32]) { + %0 = add(%x, %x); + add(%0, %y) + }; + %2 = fn (%z: Tensor[(1), float32]) { + add(%0, %z) + }; + (%1, %2) + } + """ + check_basic_block_normal_form(body) + + +@pytest.mark.xfail(raises=tvm.error.TVMError) +def test_higher_order_nested(): + x = relay.var('x', dtype='float32', shape=(1,)) + s = relay.var('s', dtype='float32', shape=(1,)) + shared = relay.add(s, s) + func_true = relay.Function([x], relay.add(x, shared)) + choice_t = relay.FuncType([], relay.scalar_type('bool')) + f = relay.Var('f', choice_t) + z = relay.Var('z') + body = relay.If(f(), func_true, relay.Function([z], relay.add(z, shared))) + top = relay.Function([f, s], body) + """ + fn (%f: fn () -> bool, %s: Tensor[(1), float32]) { + %0 = %f(); + if (%0) { + fn (%x: Tensor[(1), float32]) { + %1 = add(%s, %s); + add(%x, %1) + } + } else { + fn (%z) { + add(%z, %1) + } + } + } + """ + check_basic_block_normal_form(top) + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/python/relay/test_pass_to_basic_block_normal_form.py b/tests/python/relay/test_pass_to_basic_block_normal_form.py new file mode 100644 index 000000000000..05c6544503c1 --- /dev/null +++ b/tests/python/relay/test_pass_to_basic_block_normal_form.py @@ -0,0 +1,482 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import numpy as np +import tvm +from tvm import te +from tvm import relay +from tvm.relay.analysis import detect_feature +from tvm.relay import op, create_executor, transform +from tvm.relay.prelude import Prelude +from tvm.relay.testing import add_nat_definitions, count +from tvm.relay.analysis import Feature +from tvm.relay.analysis import check_basic_block_normal_form + + +def run_opt_pass(expr, passes): + passes = passes if isinstance(passes, list) else [passes] + mod = tvm.IRModule.from_expr(expr) + seq = tvm.transform.Sequential(passes) + with tvm.transform.PassContext(opt_level=3): + mod = seq(mod) + entry = mod["main"] + return entry if isinstance(expr, relay.Function) else entry.body + + +def check_eval(expr, expected_result, mod=None, rtol=1e-07): + ctx = tvm.context("llvm", 0) + intrp = create_executor(mod=mod, ctx=ctx, target="llvm") + + result = intrp.evaluate(expr) + np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol) + + +def test_no_explicit_bind(): + x = relay.const(1) + y = op.add(x, x) + z = op.add(y, y) + f = relay.Function([], op.add(z, z)) + """ + fn () { + %0 = add(1, 1); + %1 = add(%0, %0); + add(%1, %1) + } + """ + assert not Feature.fLet in detect_feature(f) + bblock = run_opt_pass(f, transform.ToBasicBlockNormalForm()) + assert Feature.fLet not in detect_feature(bblock) + check_eval(f(), 8.0) + check_eval(bblock(), 8.0) + check_basic_block_normal_form(bblock) + +def test_top_level_nested_if(): + x = relay.var('x', shape=(), dtype='bool') + y = relay.var('y', shape=(), dtype='float32') + z = relay.var('z', shape=(), dtype='float32') + cond_t = relay.const(True) + cond_f = relay.const(False) + one = relay.const(1, dtype='float32') + three = relay.const(3, dtype='float32') + y2 = relay.add(y, y) + z2 = relay.add(z, z) + true_branch = relay.If(cond_t, relay.add(z2, y2), relay.add(three, y2)) + false_branch = relay.If(cond_f, z2, one) + body = relay.If(x, true_branch, false_branch) + """ + free_var %x: bool + if (%x) { + if (True) { + free_var %z: float32 + %0 = add(%z, %z); + free_var %y: float32 + %1 = add(%y, %y); + add(%0, %1) + } else { + add(3f, %1) + } + } else { + if (False) { + %0 + } else { + 1f + } + } + """ + def expected(): + x = relay.var('x', shape=(), dtype='bool') + y = relay.var('y', shape=(), dtype='float32') + z = relay.var('z', shape=(), dtype='float32') + cond_t = relay.const(True) + cond_f = relay.const(False) + one = relay.const(1, dtype='float32') + three = relay.const(3, dtype='float32') + y2 = relay.var('y2') + z2 = relay.var('z2') + true_branch = relay.If(cond_t, relay.add(z2, y2), relay.add(three, y2)) + true_branch = relay.Let(y2, relay.add(y, y), true_branch) + false_branch = relay.If(cond_f, z2, one) + body = relay.If(x, true_branch, false_branch) + body = relay.Let(z2, relay.add(z, z), body) + return body + + bblock = run_opt_pass(body, [transform.ToBasicBlockNormalForm()]) + """ + free_var %z: float32 + let %x: float32 = add(%z, %z) /* ty=float32 */; + free_var %x1: bool + if (%x1) { + free_var %y: float32 + let %x2: float32 = add(%y, %y) /* ty=float32 */; + if (True /* ty=bool */) { + add(%x, %x2) /* ty=float32 */ + } else { + add(3f /* ty=float32 */, %x2) /* ty=float32 */ + } + } else { + if (False /* ty=bool */) { + %x + } else { + 1f /* ty=float32 */ + } + } + """ + expected_output = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(bblock, expected_output, map_free_vars=True) + +def test_nested_if(): + x = relay.var('x', shape=(), dtype='bool') + y = relay.var('y', shape=(), dtype='float32') + cond_t = relay.const(True) + cond_f = relay.const(False) + one = relay.const(1, dtype='float32') + two = relay.const(2, dtype='float32') + three = relay.const(3, dtype='float32') + y2 = relay.add(y, y) + true_branch = relay.If(cond_t, y2, relay.add(three, y2)) + false_branch = relay.If(cond_f, two, one) + body = relay.If(x, true_branch, false_branch) + """ + free_var %x: bool + if (%x) { + if (True) { + free_var %y: float32 + %0 = add(%y, %y); + %0 + } else { + add(3f, %0) + } + } else { + if (False) { + 2f + } else { + 1f + } + } + """ + def expected(): + x = relay.var('x', shape=(), dtype='bool') + y = relay.var('y', shape=(), dtype='float32') + cond_t = relay.const(True) + cond_f = relay.const(False) + one = relay.const(1, dtype='float32') + two = relay.const(2, dtype='float32') + three = relay.const(3, dtype='float32') + y2 = relay.var('y2') + true_branch = relay.If(cond_t, y2, relay.add(three, y2)) + true_branch = relay.Let(y2, relay.add(y, y), true_branch) + false_branch = relay.If(cond_f, two, one) + body = relay.If(x, true_branch, false_branch) + return body + + bblock = run_opt_pass(body, [transform.ToBasicBlockNormalForm()]) + """ + free_var %x: bool + if (%x) { + free_var %y: float32 + let %x1: float32 = add(%y, %y) /* ty=float32 */; + if (True /* ty=bool */) { + %x1 + } else { + add(3f /* ty=float32 */, %x1) /* ty=float32 */ + } + } else { + if (False /* ty=bool */) { + 2f /* ty=float32 */ + } else { + 1f /* ty=float32 */ + } + } + """ + expected_output = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(bblock, expected_output, map_free_vars=True) + check_basic_block_normal_form(bblock) + + +# make sure we do not infinite loop. +# it is too large so we won't check for the exact program. +def test_recursion(): + """ + Program: + let f(n: i32) -> i32 = { + m = (n * 2) + if (n == 0) { + return m; + } else { + return m + f(n - 1); + } + } + f(5); + """ + mod = tvm.IRModule() + i64 = relay.TensorType((), 'int64') + f = relay.GlobalVar("f") + n = relay.Var("n", i64) + m = n * relay.const(2, 'int64') + cond = relay.equal(n, relay.const(0, 'int64')) + false_branch = m + f(n - relay.const(1, 'int64')) + funcbody = relay.If(cond, m, false_branch) + value = relay.Function([n], funcbody, i64, []) + mod[f] = value + check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod) + old_f = mod[f] + mod = transform.ToBasicBlockNormalForm()(mod) + f = mod[f] + check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod) + check_basic_block_normal_form(f) + +def test_ref(): + i = relay.Var('i') + iv = relay.Var('iv') + u = relay.Var('u') + uv = relay.Var('uv') + body = relay.add(iv, uv) + body = relay.Let(uv, relay.RefRead(i), body) + body = relay.Let(u, relay.RefWrite(i, relay.const(2)), body) + body = relay.Let(iv, relay.RefRead(i), body) + body = relay.Let(i, relay.RefCreate(relay.const(1)), body) + check_eval(body, 3) + opt_body = run_opt_pass(body, transform.ToBasicBlockNormalForm()) + check_eval(opt_body, 3) + check_basic_block_normal_form(opt_body) + + +def test_nat_add(): + mod = tvm.IRModule() + p = Prelude(mod) + add_nat_definitions(p) + nat = p.nat + add = p.add + s = p.s + z = p.z + ctx = tvm.context("llvm", 0) + intrp = create_executor(mod=mod, ctx=ctx, target="llvm") + assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat()) + assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2 + expr = add(s(z()), s(z())) + f = relay.GlobalVar("f") + mod[f] = relay.Function([], expr) + mod = transform.ToBasicBlockNormalForm()(mod) + opt_expr = mod["f"] + assert count(p, intrp.evaluate(opt_expr.body)) == 2 + assert not Feature.fLet in detect_feature(mod[add]) + check_basic_block_normal_form(opt_expr) + +def test_let(): + def test_let1(): + x = relay.Var("x") + c = relay.const(4.0, 'float32') + body = relay.Let(x, c, x) + body = run_opt_pass(body, transform.InferType()) + """ + let %x: float32 = 4f /* ty=float32 */; + %x + """ + opt_body = run_opt_pass(body, transform.ToBasicBlockNormalForm()) + assert tvm.ir.structural_equal(body, opt_body) + check_basic_block_normal_form(opt_body) + + def test_let1_1(): + x = relay.Var("y") + d = relay.const(4.0, 'float32') + body = relay.Let(x, d, relay.add(x,x)) + body = run_opt_pass(body, transform.InferType()) + opt_body = run_opt_pass(body, transform.ToBasicBlockNormalForm()) + assert tvm.ir.structural_equal(body, opt_body) + check_basic_block_normal_form(opt_body) + + def test_let2(): + x = relay.Var("x") + y = relay.Var("y") + d = relay.const(4.0, 'float32') + body = relay.Let(y, x, x) + body = relay.Let(x, d, body) + body = run_opt_pass(body, transform.InferType()) + check_eval(body, 4) + + def expected(): + x = relay.Var("x") + y = relay.Var("y") + d = relay.const(4.0, 'float32') + body = relay.Let(y, x, y) + body = relay.Let(x, d, body) + return body + + opt_body = run_opt_pass(body, transform.ToBasicBlockNormalForm()) + expected_body = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(opt_body, expected_body) + check_basic_block_normal_form(opt_body) + + def test_let3(): + x = relay.Var("x") + y = relay.Var("y") + z = relay.Var("z") + c = relay.const(3.0, 'float32') + d = relay.const(4.0, 'float32') + body = relay.Let(z, x + y, x + z) + body = relay.Let(x, d, body) + body = relay.Let(y, c, body) + body = run_opt_pass(body, transform.InferType()) + opt_body = run_opt_pass(body, transform.ToBasicBlockNormalForm()) + assert tvm.ir.structural_equal(body, opt_body) + check_basic_block_normal_form(opt_body) + + test_let1() + test_let1_1() + test_let2() + test_let3() + +def test_function(): + t = relay.TensorType((), 'float32') + x = relay.Var("x", t) + f = relay.Function([x], x + x) + d = relay.const(4.0, 'float32') + bblock = run_opt_pass(f, transform.ToBasicBlockNormalForm()) + assert isinstance(bblock, relay.Function) + check_eval(f(d), 8) + check_eval(bblock(d), 8) + check_basic_block_normal_form(bblock) + +def test_gradient_if(): + x = relay.var("a", shape=(1, 16)) + y = relay.var("y", shape=(1, 16)) + cond = relay.var("cond", shape=(), dtype='uint1') + net = relay.If(cond, x, x) + net = relay.add(x, net) + net = relay.Function([cond,x,y], net) + mod = tvm.IRModule.from_expr(net) + mod = relay.transform.ToBasicBlockNormalForm()(mod) + net_grad = relay.transform.gradient(mod["main"], mode='higher_order') + mod["main"] = net_grad + mod_grad = relay.transform.ToBasicBlockNormalForm()(mod) + check_basic_block_normal_form(mod_grad['main']) + check_basic_block_normal_form(mod['main']) + +def test_if(): + def if_expr(x): + """ + free_var %x: float32 + %0 = equal(%x, 2f); + if (%0) { + %1 = add(%x, 1f); + multiply(%1, 2f) + } else { + multiply(%1, 1f) + } + """ + one = relay.const(1, dtype='float32') + two = relay.const(2, dtype='float32') + v1 = relay.add(x, one) + v2 = relay.equal(x, two) + true_branch = relay.multiply(v1, two) + false_branch = relay.multiply(v1, one) + body = relay.If(v2, true_branch, false_branch) + return body + + def expected_if_expr(x): + """ + free_var %x: float32 + let %v1: float32 = add(%x, 1f /* ty=float32 */) /* ty=float32 */; + %0 = equal(%x, 2f /* ty=float32 */) /* ty=bool */; + if (%0) { + multiply(%v1, 2f /* ty=float32 */) /* ty=float32 */ + } else { + multiply(%v1, 1f /* ty=float32 */) /* ty=float32 */ + } + """ + one = relay.const(1, dtype='float32') + two = relay.const(2, dtype='float32') + v1 = relay.var('v1') + v2 = relay.equal(x, two) + true_branch = relay.multiply(v1, two) + false_branch = relay.multiply(v1, one) + body = relay.If(v2, true_branch, false_branch) + body = relay.Let(v1, relay.add(x, one), body) + return body + + x = relay.var('x', shape=(), dtype='float32') + body = if_expr(x) + expected_body = expected_if_expr(x) + bblock = run_opt_pass(body, transform.ToBasicBlockNormalForm()) + expected_bblock = run_opt_pass(expected_body, transform.InferType()) + assert tvm.ir.structural_equal(bblock, expected_bblock, map_free_vars=True) + check_basic_block_normal_form(bblock) + + func = relay.Function([x], body) + expected_func = relay.Function([x], expected_body) + bblock = run_opt_pass(func, transform.ToBasicBlockNormalForm()) + expected_bblock = run_opt_pass(expected_func, transform.InferType()) + assert tvm.ir.structural_equal(bblock, expected_bblock) + check_basic_block_normal_form(bblock) + +def test_higher_order_return(): + x = relay.var('x', shape=(1,), dtype='float32')#, a) + y = relay.var('y', shape=(1,), dtype='float32')#, a) + z = relay.var('z', shape=(1,), dtype='float32')#, a) + x2 = relay.add(x, x) + func_a = relay.Function([y], relay.add(x2, y)) #, a, [a]) + func_b = relay.Function([z], relay.add(x2, z)) #, a, [a]) + body = relay.Tuple([func_a, func_b]) + body = relay.Function([x], body) + """ + fn (%x: Tensor[(1), float32]) { + %1 = fn (%y: Tensor[(1), float32]) { + %0 = add(%x, %x); + add(%0, %y) + }; + %2 = fn (%z: Tensor[(1), float32]) { + add(%0, %z) + }; + (%1, %2) + } + """ + + bblock = run_opt_pass(body, transform.ToBasicBlockNormalForm()) + check_basic_block_normal_form(bblock) + + +def test_higher_order_nested(): + x = relay.var('x', dtype='float32', shape=(1,)) + s = relay.var('s', dtype='float32', shape=(1,)) + shared = relay.add(s, s) + func_true = relay.Function([x], relay.add(x, shared)) + choice_t = relay.FuncType([], relay.scalar_type('bool')) + f = relay.Var('f', choice_t) + z = relay.Var('z') + body = relay.If(f(), func_true, relay.Function([z], relay.add(z, shared))) + top = relay.Function([f, s], body) + """ + fn (%f: fn () -> bool, %s: Tensor[(1), float32]) { + %0 = %f(); + if (%0) { + fn (%x: Tensor[(1), float32]) { + %1 = add(%s, %s); + add(%x, %1) + } + } else { + fn (%z) { + add(%z, %1) + } + } + } + """ + + bblock = run_opt_pass(top, transform.ToBasicBlockNormalForm()) + check_basic_block_normal_form(bblock) + +if __name__ == '__main__': + pytest.main([__file__])