diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index aa9d3b41554c..2d6cdeaa8ca1 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -62,21 +62,31 @@ TVM_DLL Pass CreateFunctionPass( const runtime::TypedPackedFunc& pass_func, int opt_level, String name, tvm::Array required); -/*! \brief Remove expressions which does not effect the program result. +/*! \brief Remove let-bound expressions which do not affect the program result. * - * It will remove let bindings which are not referenced, - * and inline let bindings that are only used once. + * This pass will remove let bindings which are not referenced. If inline_once is True, + * let bindings which are only referenced once will also be inlined. * - * For example, this pass should turn `let a = 1 in 2` into `2`, + * For example, this pass should turn `let a = 1; 2` into `2`, * as the value of the expression does not depend on a. * - * As another example, `let a = 1 in a` will be optimized into 1. + * As another example, `let a = 1; a` will be optimized into 1 if inline_once is True. * - * \param inline_once whether or not to inline binding used one. + * If ignore_impurity is False, possibly side-effecting expressions (such as memory allocation, + * random number generation, reading/writing references, or calls to primitive or external + * functions) are never elided or inlined. This is sound, but ignore_impurity can be set to True + * to suppress this check. + * + * The analysis is fairly conservative, for example it assumes all local functions + * may be called more than once, any functions passed as arguments have side effects, + * and so on. + * + * \param inline_once whether or not to inline bindings used exactly once. + * \param ignore_impurity whether to ignore possible side-effects in let-bound expressions + * * \return the pass. 
*/ -TVM_DLL Pass DeadCodeElimination(bool inline_once = false); +TVM_DLL Pass DeadCodeElimination(bool inline_once = false, bool ignore_impurity = false); /*! * \brief Convert all expressions of TensorType into GradCell, diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 01473a82fb3a..06b462f4e41f 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -209,20 +209,22 @@ def CanonicalizeOps(): return _ffi_api.CanonicalizeOps() -def DeadCodeElimination(inline_once=False): +def DeadCodeElimination(inline_once=False, ignore_impurity=False): """Remove expressions that do not have any users (dead code). Parameters ---------- inline_once: Optional[Bool] - Whether to inline binding that occurs only once. + Whether to inline a binding that is referenced exactly once. + ignore_impurity: Optional[Bool] + Whether to ignore possible side-effects in let-bound expressions. Returns ------- ret: tvm.transform.Pass The registered pass that eliminates the dead code in a Relay program. 
""" - return _ffi_api.DeadCodeElimination(inline_once) + return _ffi_api.DeadCodeElimination(inline_once, ignore_impurity) def LazyGradientInit(): diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index 0574fd50f4b6..315ad9c3b6a5 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -90,10 +90,15 @@ RELAY_REGISTER_OP("memory.alloc_storage") .set_attrs_type_key("relay.attrs.AllocStorageAttrs") .set_support_level(10) .set_attr("TOpPattern", kOpaque) - .set_attr("TOpIsStateful", false) + .set_attr("TOpIsStateful", true) .set_attr("TNonComputational", true) .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); +const Op& MemoryAllocTensorOp() { + static const Op& op = Op::Get("memory.alloc_tensor"); + return op; +} + Expr AllocTensor(Expr storage, Expr offset, Expr shape, DataType dtype, Array assert_shape) { auto attrs = make_object(); @@ -106,8 +111,7 @@ Expr AllocTensor(Expr storage, Expr offset, Expr shape, DataType dtype, ICHECK(constant_node); attrs->const_shape = GetRef(constant_node); } - static const Op& op = Op::Get("memory.alloc_tensor"); - return Call(op, {storage, offset, shape}, Attrs(attrs), {}); + return Call(MemoryAllocTensorOp(), {storage, offset, shape}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_tensor").set_body_typed(AllocTensor); @@ -196,7 +200,7 @@ RELAY_REGISTER_OP("memory.alloc_tensor") .set_attrs_type_key("relay.attrs.AllocTensorAttrs") .set_support_level(10) .set_attr("TOpPattern", kOpaque) - .set_attr("TOpIsStateful", false) + .set_attr("TOpIsStateful", true) .set_attr("TNonComputational", true) .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); diff --git a/src/relay/op/memory/memory.h b/src/relay/op/memory/memory.h index 618044a9f2ca..9e93afdcfa37 100644 --- a/src/relay/op/memory/memory.h +++ b/src/relay/op/memory/memory.h @@ -35,6 +35,8 @@ namespace tvm { namespace relay { Expr AllocStorage(Expr size, Expr alignment, SEScope se_scope, DataType 
dtype_hint); +/*! \brief Returns the "memory.alloc_tensor" operator. */ +const Op& MemoryAllocTensorOp(); Expr AllocTensor(Expr storage, Expr offset, tvm::relay::Expr shape, DataType dtype, Array assert_shape); Expr ToTupleType(const Type& ty, const std::vector& exprs); diff --git a/src/relay/op/vm/vm.cc b/src/relay/op/vm/vm.cc index 65a4ec01805b..281cca5e847a 100644 --- a/src/relay/op/vm/vm.cc +++ b/src/relay/op/vm/vm.cc @@ -61,7 +61,7 @@ Expr ShapeOf(Expr expr) { auto attrs = make_object(); attrs->dtype = DataType::Int(64); static const Op& op = Op::Get("vm.shape_of"); - return Call(op, {expr}, Attrs(attrs), {}); + return Call(op, {std::move(expr)}, Attrs(std::move(attrs)), {}); } TVM_REGISTER_GLOBAL("relay.op.vm.shape_of").set_body_typed(ShapeOf); @@ -156,7 +156,9 @@ bool InvokeTVMOpRel(const Array& types, int num_inputs, const Attrs& attrs if (func_type->ret_type.as()) { ex_output = TupleType({func_type->ret_type}); } else { - ICHECK(func_type->ret_type.as()) << "should be tuple type"; + ICHECK(func_type->ret_type.as()) + << "expecting function result to be tuple type. 
Types:" << std::endl + << PrettyPrint(types); ex_output = func_type->ret_type; } auto ex_input = TupleType(func_type->arg_types); @@ -167,10 +169,14 @@ bool InvokeTVMOpRel(const Array& types, int num_inputs, const Attrs& attrs } Expr InvokeTVMOp(Expr func, Expr inputs, Expr outputs) { - return Call(Op::Get("vm.invoke_tvm_op"), {func, inputs, outputs}, Attrs()); + static const Op& op = Op::Get("vm.invoke_tvm_op"); + return Call(op, {std::move(func), std::move(inputs), std::move(outputs)}, {}); } -TVM_REGISTER_GLOBAL("relay.op.vm.invoke_tvm_op").set_body_typed(InvokeTVMOp); +TVM_REGISTER_GLOBAL("relay.op.vm.invoke_tvm_op") + .set_body_typed([](Expr func, Expr inputs, Expr outputs) { + return InvokeTVMOp(std::move(func), std::move(inputs), std::move(outputs)); + }); RELAY_REGISTER_OP("vm.invoke_tvm_op") .describe(R"code(Invoke an operation compiled by TVM.)code" TVM_ADD_FILELINE) @@ -179,9 +185,10 @@ RELAY_REGISTER_OP("vm.invoke_tvm_op") .add_argument("ins", "Tuple", "The input tensors.") .add_argument("outs", "Tuple", "The output tensors.") .add_type_rel("InvokeTVMOp", InvokeTVMOpRel) + .set_attrs_type_key("DictAttrs") .set_support_level(10) .set_attr("TOpPattern", kOpaque) - .set_attr("TOpIsStateful", false) + .set_attr("TOpIsStateful", true) .set_attr("TNonComputational", true) .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); @@ -217,7 +224,7 @@ Expr ReshapeTensor(Expr data, Expr shape, Array newshape) { static const Op& op = Op::Get("vm.reshape_tensor"); auto attrs = make_object(); attrs->newshape = std::move(newshape); - return Call(op, {data, shape}, Attrs(attrs), {}); + return Call(op, {std::move(data), std::move(shape)}, Attrs(std::move(attrs)), {}); } TVM_REGISTER_GLOBAL("relay.op.vm.reshape_tensor").set_body_typed(ReshapeTensor); diff --git a/src/relay/transforms/dead_code.cc b/src/relay/transforms/dead_code.cc index 26624e438b8a..d73fc3ea9c38 100644 --- a/src/relay/transforms/dead_code.cc +++ b/src/relay/transforms/dead_code.cc @@ -18,158 +18,566 
@@ */ /*! + * \file src/relay/transforms/dead_code.cc + * \brief Elides or inlines let-bindings. * - * \file dead_code.cc - * - * \brief Remove code that does not effect the program result. - * - * The algorithm is implemented by two visitor: - * CalcDep turn an expr into a dependency graph of expr, - * GenLet turn the dependency graph into a let list, taking only the used value. + * TODO(mbs): Track dead writes into references. */ + #include #include +#include #include -#include "let_list.h" +#include "../op/call/call.h" namespace tvm { namespace relay { +namespace { -template -using VarMap = std::unordered_map; -using VarSet = std::unordered_set; +/*! \brief Maximum depth of calls to analyize. */ +constexpr int kMaxCallDepth = 25; + +/*! + * \brief Captures (an approximation of) the purity for a Relay sub-expression. A pure + * sub-expression is guaranteed never to access or mutate state. Thus the sub-expression + * can safely be elided (if its result is never used), or inlined (which may change the + * number of times and program order for the evaluation.) + */ +struct Purity { + /*! + * \brief True if evaling the sub-expression itself is pure. + */ + bool pure_eval; + /*! + * \brief If the sub-expression is first-order then always true. Otherwise true only if evaling + * a call to the the sub-expression is pure. See [RULE A] below. + */ + bool pure_call; +}; + +/*! + * \brief Visits all the global functions in a module and records the purity of every let-bound + * value. + * + * (See also inline.cc for function inlining.) + * + * Generally we track whether evaluation of a sub-expression is definitely pure. However for + * sub-expressions f of higher-order type we also track the 'call purity' of evaling a call to f: + * - [RULE A] If f's result is itself higher-order then f is call-pure only if the result of f is + * also call-pure. + * - [RULE B] Higher-order function arguments are assumed call impure. 
+ * - [RULE C] We assume functions extracted from tuples are call impure. + * - [RULE D] We assume functions extracted from references are call impure. + * - [RULE E] We assume functions extracted from ADTs are call impure. + * - [RULE F] We assume all external Functions and PrimFuncs are call impure. + */ +class PurityVisitor : ExprFunctor { + public: + explicit PurityVisitor(IRModule mod) : mod_(std::move(mod)), current_call_depth_(0) {} + + /*! \brief Visit all the functions in the module. */ + void VisitModule() { + VLOG_CONTEXT << "PurityVisitor"; + // It is safe to visit the global functions in any order. Recursive global functions are + // allowed. + for (const auto& kv : mod_->functions) { + if (const auto* function_node = kv.second.as()) { + if (function_node->HasNonzeroAttr(attr::kPrimitive) || + function_node->GetAttr(attr::kExternalSymbol)) { + // Ignore primitive and external functions. + continue; + } + // Everything of interest will be recorded in the purity maps so we ignore the result. + (void)VisitGlobalFunction(kv.first, GetRef(function_node)); + } + } + } + + /*! + * \brief Returns a map from every let-bound variable to whether its let-bound value is + * definitely pure. 
+ */ + std::unordered_map GetPurityMap() const { + std::unordered_map result; + for (const auto& kv : var_to_purity_) { + result.emplace(kv.first, kv.second.pure_eval); + } + return result; + } -class CalcDep; -class FindDef : private ExprVisitor { private: - VarMap expr_map_; + Purity VisitExpr(const Expr& expr) final { + auto it = memo_.find(expr.get()); + if (it != this->memo_.end()) { + return it->second; + } else { + Purity result = ExprFunctor::VisitExpr(expr); + memo_[expr.get()] = result; + return result; + } + } - void VisitExpr_(const LetNode* l) final { - auto pre_visit = [this](const LetNode* op) { - ICHECK_EQ(expr_map_.count(op->var), 0); - expr_map_[op->var] = op->value; - this->VisitExpr(op->value); - }; - auto post_visit = [this](const LetNode* op) { - this->VisitExpr(op->body); - this->visit_counter_[op] += 1; - }; - ExpandANormalForm(l, pre_visit, post_visit); + Purity VisitExpr_(const ConstantNode*) final { return {/*pure_eval=*/true, /*pure_call=*/true}; } + + Purity VisitExpr_(const ConstructorNode*) final { + return {/*pure_eval=*/true, /*pure_call=*/true}; + } + + Purity VisitExpr_(const OpNode* op_node) final { + // Primitive operators are pure unless marked as 'stateful'. + static OpAttrMap attr_map = Op::GetAttrMap("TOpIsStateful"); + bool is_stateful = attr_map.count(GetRef(op_node)) && attr_map[GetRef(op_node)]; + return {/*pure_eval=*/true, /*pure_call=*/!is_stateful}; + } + + Purity VisitExpr_(const GlobalVarNode* global_var_node) final { + auto global_var = GetRef(global_var_node); + auto func = mod_->Lookup(global_var); + if (const auto* function_node = func.as()) { + if (!function_node->GetAttr(attr::kExternalSymbol)) { + return VisitGlobalFunction(global_var, GetRef(function_node)); + } + } + // Assume externals and PrimFuncs are call-impure [RULE F]. + // (If they are pure then we should have dealt with them before lowering.) 
+ return {/*pure_eval==*/true, /*pure_call=*/false}; + } + + Purity VisitExpr_(const VarNode* var_node) final { + // The var is bound to a value, but if that value is a function we need to propagate the + // function body's purity. + ICHECK(var_to_purity_.count(var_node)) << PrettyPrint(GetRef(var_node)); + return {/*pure_eval=*/true, /*pure_call=*/var_to_purity_[var_node].pure_call}; + } + + Purity VisitExpr_(const FunctionNode* function_node) final { + for (const auto& param : function_node->params) { + // Any higher-order parameters are assumed to be call-impure [RULE B] + var_to_purity_[param.get()] = {/*pure_eval=*/true, /*pure_call=*/IsFirstOrder(param)}; + } + Purity body_purity = VisitExpr(function_node->body); + // The function itself is a value and thus pure. If the function returns + // a function we'll fold its purity in here [RULE A] + return {/*pure_eval=*/true, /*pure_call=*/body_purity.pure_eval && body_purity.pure_call}; + } + + Purity VisitExpr_(const LetNode* let_node) final { + Expr expr = GetRef(let_node); + bool all_values_pure_eval = true; + while (const auto* inner_let_node = expr.as()) { + // In case the value is a recursive function assume the let-bound variable is call-pure. + var_to_purity_[inner_let_node->var.get()] = {/*pure_eval=*/true, /*pure_call=*/true}; + Purity value_purity = VisitExpr(inner_let_node->value); + // Now revise the variable to it's true purity. + var_to_purity_[inner_let_node->var.get()] = value_purity; + VLOG(2) << (value_purity.pure_eval ? 
"pure" : "impure") << " expression:" << std::endl + << PrettyPrint(inner_let_node->value) << std::endl + << "let-bound to variable:" << std::endl + << PrettyPrint(inner_let_node->var); + all_values_pure_eval = all_values_pure_eval && value_purity.pure_eval; + expr = inner_let_node->body; + } + Purity body_purity = VisitExpr(expr); + return {/*pure_eval=*/all_values_pure_eval && body_purity.pure_eval, + /*pure_call=*/body_purity.pure_call}; + } + + Purity VisitExpr_(const CallNode* call_node) final { + if (current_call_depth_ >= kMaxCallDepth) { + // Assume impure. + VLOG(2) << "assuming call is impure since too deeply nested"; + return {/*pure_eval=*/false, /*pure_call*/ IsFirstOrder(GetRef(call_node))}; + } + + ++current_call_depth_; + + // We can work with the call in both pre- and post-lowered form. + Expr callee; + Array args; + if (call_node->op == CallLoweredOp()) { + CallLoweredProps props = GetCallLoweredProps(call_node); + callee = props.lowered_func; + args = props.arguments; + } else { + callee = call_node->op; + args = call_node->args; + } + + // Find purity for the callee and the args. + Purity callee_purity = VisitExpr(callee); + bool all_args_pure_eval = true; + for (const auto& arg : args) { + Purity arg_purity = VisitExpr(arg); + all_args_pure_eval = all_args_pure_eval && arg_purity.pure_eval; + } + + VLOG(2) << (callee_purity.pure_call ? "pure" : "impure") << " call to:" << std::endl + << PrettyPrint(callee); + + ICHECK_GT(current_call_depth_, 0); + --current_call_depth_; + + // If the callee's result is itself a function then by [RULE A] its purity + // is given by callee_purity.pure_call. 
+ return {/*pure_eval=*/all_args_pure_eval && callee_purity.pure_eval && callee_purity.pure_call, + /*pure_call=*/IsFirstOrder(GetRef(call_node)) || callee_purity.pure_call}; + } + + Purity VisitExpr_(const IfNode* if_node) final { + Purity cond_purity = VisitExpr(if_node->cond); + ICHECK(cond_purity.pure_call); // conditional is first-order + Purity true_purity = VisitExpr(if_node->true_branch); + Purity false_purity = VisitExpr(if_node->false_branch); + return {/*pure_eval=*/cond_purity.pure_eval && true_purity.pure_eval && false_purity.pure_eval, + /*pure_call=*/true_purity.pure_call && false_purity.pure_call}; + } + + Purity VisitExpr_(const TupleNode* tuple_node) final { + bool all_fields_pure = true; + for (const auto& field : tuple_node->fields) { + // The call purity of each tuple field is lost [RULE C]. + Purity field_purity = VisitExpr(field); + if (!field_purity.pure_eval) { + all_fields_pure = false; + } + } + return {/*pure_eval=*/all_fields_pure, /*pure_call=*/true}; + } + + Purity VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { + Purity tuple_purity = VisitExpr(tuple_get_item_node->tuple); + ICHECK(tuple_purity.pure_call); // tuple is first-order + // We don't track call purity through tuple fields, so if the result is a function type we + // must assume it is call impure [RULE C]. + return {/*pure_eval=*/tuple_purity.pure_eval, + /*pure_call=*/IsFirstOrder(GetRef(tuple_get_item_node))}; + } + + Purity VisitExpr_(const RefCreateNode*) final { + // The creation of the ref itself is unobservable other than via the reads/writes into it. + return {/*pure_eval=*/true, /*pure_call=*/true}; + } + + Purity VisitExpr_(const RefWriteNode* ref_write_node) final { + Purity ref_purity = VisitExpr(ref_write_node->ref); + ICHECK(ref_purity.pure_call); // reference is first-order + // The call purity of the written value is lost [RULE D]. + // (But we must still visit to accumulate purity for any let-bindings within in.) 
+ (void)VisitExpr(ref_write_node->value); + return {/*pure_eval=*/false, /*pure_call=*/true}; + } + + Purity VisitExpr_(const RefReadNode* ref_read_node) final { + Purity ref_purity = VisitExpr(ref_read_node->ref); + ICHECK(ref_purity.pure_call); // reference is first-order + // We don't track call purity through reference values, so if the result is a function + // type we must assume it is call impure [RULE D]. + return {/*pure_eval=*/false, /*pure_call=*/IsFirstOrder(GetRef(ref_read_node))}; } - friend CalcDep; + class PurityPatternVisitor : public PatternVisitor { + public: + explicit PurityPatternVisitor(PurityVisitor* outer) : outer_(outer) {} + + private: + void VisitPattern_(const PatternVarNode* pattern_var_node) final { + // We don't track call purity through ADTs, so if var is a function type we must assume + // it is call impure [RULE E]. + outer_->var_to_purity_[pattern_var_node->var.get()] = { + /*pure_eval=*/true, /*pure_call=*/IsFirstOrder(pattern_var_node->var)}; + } + + /*! \brief (Mutable borrow of) the outer visitor. */ + PurityVisitor* outer_; + }; + + Purity VisitExpr_(const MatchNode* match_node) final { + Purity data_purity = VisitExpr(match_node->data); + ICHECK(data_purity.pure_call); // ADT is first order + bool all_clauses_pure_eval = true; + bool all_clauses_pure_call = true; + for (const auto& clause : match_node->clauses) { + PurityPatternVisitor pattern_visitor(this); + pattern_visitor.VisitPattern(clause->lhs); + Purity rhs_purity = VisitExpr(clause->rhs); + all_clauses_pure_eval = all_clauses_pure_eval && rhs_purity.pure_eval; + all_clauses_pure_call = all_clauses_pure_call && rhs_purity.pure_call; + } + return {/*pure_eval=*/data_purity.pure_eval && all_clauses_pure_eval, + /*pure_call=*/all_clauses_pure_call}; + } + + /*! \brief Visits \p func bound to global \p var and returns it's purity. 
*/ + Purity VisitGlobalFunction(const GlobalVar& var, const Function& func) { + VLOG_CONTEXT << "func " << var->name_hint; + VLOG(2) << "visiting"; + auto itr = global_var_to_purity_.find(var.get()); + if (itr != global_var_to_purity_.end()) { + // We've already visited the function body. + return itr->second; + } + // We are entering the body of a possibly-recursive global function. Assume it's body is pure. + global_var_to_purity_[var.get()] = {/*pure_eval=*/true, /*pure_call=*/true}; + // Visit the global function for the first time. + Purity func_purity = VisitExpr(func); + // Update with the true purity. + global_var_to_purity_[var.get()] = func_purity; + return func_purity; + } + + static bool IsFirstOrder(const Expr& expr) { + return expr->checked_type().as() == nullptr; + } + + /*! \brief The module we're analyzing. */ + IRModule mod_; + + /*! + * \brief Maps each let-bound and global variable to the purity of the value it is bound to. + * If the variable is bound to a function then the purity of saturating that function is also + * tracked. + * + * Note that global_var_to_purity_, and all the 'pure_call' fields, are only needed internally + * during the analysis, andonly the var_to_purity_ 'pure_eval' fields are used downstream. + */ + std::unordered_map var_to_purity_; + std::unordered_map global_var_to_purity_; + + /*! \brief The current call depth. We'll just assume deeply nested calls are impure rather than + * spending all that time to check for sure. A deeply nested call is almost certain to be needed + * anyway. + */ + + int current_call_depth_; + + /*! \brief Internal map used for memoization. */ + std::unordered_map memo_; +}; + +/*! + * \brief Accumulate the bound values and usage count for each let-bound variable. + * + * We don't attempt to track the number of calls to local functions, and instead just assume they + * are called at least twice. + */ +class UsageVisitor : public ExprVisitor { + public: + /*! 
\brief Accumulates the expression bound to every let-bound variable. */ + std::unordered_map let_bound_values_; + /*! \brief Accumulates the usage count for every let-bound variable. */ + std::unordered_map use_map_; + + explicit UsageVisitor(const std::unordered_map* var_to_purity, + bool default_purity) + : var_to_purity_(var_to_purity), default_purity_(default_purity) {} + + void VisitExpr(const Expr& expr) final { + // Once we've seen 2 usages of a variable we know it can be neither elided nor inlined, + // so can stop visiting again. + if (++visit_counter_[expr.get()] <= 2) { + ExprFunctor::VisitExpr(expr); + } + } + + void VisitExpr_(const FunctionNode* function_node) final { + ++current_scope_level_; + ExprVisitor::VisitExpr_(function_node); + ICHECK_GT(current_scope_level_, 0); + --current_scope_level_; + } + + void VisitExpr_(const LetNode* let_node) final { + Expr expr = GetRef(let_node); + while (const auto* inner_let_node = expr.as()) { + ++visit_counter_[inner_let_node]; + let_bound_values_[inner_let_node->var.get()] = inner_let_node->value; + VLOG(2) << "seen let-binding for:" << std::endl << PrettyPrint(inner_let_node->var); + use_map_[inner_let_node->var.get()] = 0; + scope_level_map_[inner_let_node->var.get()] = current_scope_level_; + if (is_pure(inner_let_node->var.get())) { + // We'll defer visiting the let-bound value until we've seen the first use of the let-bound + // variable and thus know it must be evaluated. + // no-op. + } else { + // The let-bound value is impure so must always be evaluated. Visit now. + VisitExpr(inner_let_node->value); + } + expr = inner_let_node->body; + } + VisitExpr(expr); + } + + void VisitExpr_(const VarNode* var_node) final { + if (let_bound_values_.count(var_node)) { + size_t& n = use_map_[var_node]; + ++n; + VLOG(2) << var_node->name_hint() << " = " << n; + if (n == 1 && is_pure(var_node)) { + // Now that we have at least one use of the let-bound var, we know the let-bound + // value is necessary. 
+ VisitExpr(let_bound_values_[var_node]); + } + if (scope_level_map_[var_node] < current_scope_level_) { + // Since the variable was bound outside of the current local function, assume the + // function will be called at least twice. + ++n; + VLOG(2) << var_node->name_hint() << " = " << n << " (bound at level " + << scope_level_map_[var_node] << " but used at level " << current_scope_level_ + << ")"; + } + } + // else: nothing to be done for function parameters or variable in match patterns. + } + + bool is_pure(const VarNode* var_node) const { + auto itr = var_to_purity_->find(var_node); + return itr == var_to_purity_->end() ? default_purity_ : itr->second; + } + + /*! \brief (Immutable borrow of) the already determined purity for every let-bound variable. */ + const std::unordered_map* var_to_purity_; + /*! \brief The default purity for variables which are not in the above map. */ + bool default_purity_; + /*! + * \brief The current scope level. 0 for global functions. Incremented by one within each + * let-bound local function. Necessary so we can avoid inlining an expensive let-bound computation + * into a function which could be called more than once. + */ + int current_scope_level_ = 0; + /*! \brief Accumulates the scope level for every let-bound variable. */ + std::unordered_map scope_level_map_; }; -class Eliminator : private ExprMutator { +/*! \brief Eliminate/inline let-bound values when sound to do so. 
*/ +class EliminatorMutator : public ExprMutator { + public: + EliminatorMutator(bool inline_once, + const std::unordered_map* let_bound_values, + const std::unordered_map* use_map, + const std::unordered_map* var_to_purity, + bool default_purity) + : inline_once_(inline_once), + let_bound_values_(let_bound_values), + use_map_(use_map), + var_to_purity_(var_to_purity), + default_purity_(default_purity) {} + private: - VarMap expr_map_; - VarMap use_map_; - bool inline_once_; - explicit Eliminator(const VarMap& expr_map, const VarMap& use_map, bool inline_once) - : expr_map_(expr_map), use_map_(use_map), inline_once_(inline_once) {} - friend CalcDep; + enum Action { kElide, kInline, kNoChange }; - bool HasLet(const Var& v) { - switch (use_map_[v]) { + /*! \brief What should we do with let-binding for \p var_node? */ + Action ActionFor(const VarNode* var_node) { + if (let_bound_values_->count(var_node) == 0) { + // Not let-bound var. + return kNoChange; + } + if (!is_pure(var_node)) { + // The let-bound value is impure -- we must leave it exactly where it is. + return kNoChange; + } + switch (use_map_->count(var_node) ? use_map_->at(var_node) : 0) { case 0: - return false; + return kElide; case 1: - return !inline_once_; + return inline_once_ ? kInline : kNoChange; default: - return true; + return kNoChange; } } - Expr VisitExpr_(const VarNode* op) final { - Var v = GetRef(op); - return (expr_map_.count(v) == 0 || HasLet(v)) ? 
v : VisitExpr(expr_map_[v]); + Expr VisitExpr_(const VarNode* var_node) final { + if (ActionFor(var_node) == kInline) { + VLOG(1) << "inlining let-bound variable:" << std::endl << PrettyPrint(GetRef(var_node)); + return VisitExpr(let_bound_values_->at(var_node)); + } else { + return GetRef(var_node); + } } Expr VisitExpr_(const LetNode* op) final { auto pre_visit = [this](const LetNode* op) { - if (HasLet(op->var)) { - Expr value = this->VisitExpr(op->value); + if (ActionFor(op->var.get()) != kElide) { + (void)VisitExpr(op->value); } }; auto post_visit = [this](const LetNode* op) { - Expr body = this->VisitExpr(op->body); + Expr body = VisitExpr(op->body); auto expr = GetRef(op); - Var v = op->var; - if (HasLet(v)) { - Expr value = this->VisitExpr(op->value); - this->memo_[expr] = Let(v, value, body); - } else { - this->memo_[expr] = body; + switch (ActionFor(op->var.get())) { + case kElide: + VLOG(1) << "eliding let-bound variable:" << std::endl << PrettyPrint(op->var); + memo_[expr] = body; + break; + case kInline: + // Already inlined at use-side. + memo_[expr] = body; + break; + case kNoChange: + Expr value = VisitExpr(op->value); + memo_[expr] = Let(op->var, value, body); + break; } }; ExpandANormalForm(op, pre_visit, post_visit); return memo_[GetRef(op)]; } -}; -// calculate the dependency graph from expression -class CalcDep : protected MixedModeVisitor { - public: - static Expr Eliminate(const Expr& e, bool inline_once) { - FindDef fd; - fd(e); - CalcDep cd(fd.expr_map_); - cd(e); - Eliminator el(fd.expr_map_, cd.use_map_, inline_once); - return el(e); + bool is_pure(const VarNode* var_node) const { + auto itr = var_to_purity_->find(var_node); + return itr == var_to_purity_->end() ? 
default_purity_ : itr->second; } - private: - explicit CalcDep(const VarMap& expr_map) : MixedModeVisitor(2), expr_map_(expr_map) {} - VarMap expr_map_; - VarMap use_map_; + bool inline_once_; + const std::unordered_map* let_bound_values_; + const std::unordered_map* use_map_; + const std::unordered_map* var_to_purity_; + bool default_purity_; +}; - using MixedModeVisitor::VisitExpr_; +} // namespace - void VisitLeaf(const Expr& e) final { - visit_counter_[e.get()]++; - // The dce code seprate variable into three parts: - // used 0 times (remove) - // used 1 times (inline) - // used 2 times (dont do anything). - if (visit_counter_[e.get()] <= 2) { - using TParent = ExprFunctor; - TParent::VisitExpr(e); - } - } +namespace transform { - void VisitExpr_(const LetNode* l) final { - Expr let_binding = GetRef(l); - const LetNode* let; - while ((let = let_binding.as())) { - let_binding = let->body; - visit_counter_[l] += 1; +// Declared in relay/transform.h +Pass DeadCodeElimination(bool inline_once, bool ignore_impurity) { + auto pass_func = [=](IRModule mod, PassContext pc) -> IRModule { + // Which let bindings are pure and can be safely elided? 
+ std::unordered_map var_to_purity; + if (!ignore_impurity) { + VLOG(1) << "determine purity"; + PurityVisitor purity_visitor(mod); + purity_visitor.VisitModule(); + var_to_purity = purity_visitor.GetPurityMap(); } - VisitExpr(let_binding); - } - void VisitExpr_(const VarNode* v) final { - Var var = GetRef(v); - ++use_map_[var]; - if (use_map_[var] == 1 && expr_map_.count(var) > 0) { - VisitExpr(expr_map_[var]); - } - } -}; + IRModule result(/*functions=*/{}, mod->type_definitions, mod->Imports(), mod->source_map); + for (const auto& kv : mod->functions) { + if (const auto* function_node = kv.second.as()) { + auto function = GetRef(function_node); -Expr DeadCodeElimination(const Expr& e, bool inline_once) { - return CalcDep::Eliminate(e, inline_once); -} + VLOG(1) << "processing " << PrettyPrint(kv.first); -namespace transform { + VLOG(2) << "count usage"; + UsageVisitor usage_visitor(&var_to_purity, /*default_purity=*/ignore_impurity); + usage_visitor.VisitExpr(function); + + // Actually eliminate/inline the let-bindings. + VLOG(2) << "eliminate"; + EliminatorMutator eliminator_mutator(inline_once, &usage_visitor.let_bound_values_, + &usage_visitor.use_map_, &var_to_purity, + /*default_purity=*/ignore_impurity); + result->Add(kv.first, Downcast(eliminator_mutator.VisitExpr(function))); + } else { + // PrimFuncs come across unchanged. 
+ result->Add(kv.first, kv.second); + } + } -Pass DeadCodeElimination(bool inline_once) { - runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(DeadCodeElimination(f, inline_once)); - }; - return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {}); + return result; + }; + return tvm::transform::CreateModulePass(pass_func, /*opt_level=*/1, "DeadCodeElimination", + {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.DeadCodeElimination").set_body_typed(DeadCodeElimination); diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 22b3983df1c3..476728664132 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -17,7 +17,6 @@ """ Support level10 operator test cases. """ import numpy as np -import pytest import tvm import tvm.testing import tvm.topi.testing @@ -59,7 +58,14 @@ def test_checkpoint_alpha_equal(): # run PE and DCE with tvm.transform.PassContext(opt_level=3): - passes = [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)] + # The expected output assumes DCE can elide 'dead writes' to references. At the time this unit test was + # written DCE would elide all writes, which though unsound in general happens to work for this case. Preserve + # that legacy behaviour here using 'ignore_impurity=True'. + # TODO(mbs): Revisit once DCE supports dead reference writes. + passes = [ + transform.PartialEvaluate(), + transform.DeadCodeElimination(inline_once=True, ignore_impurity=True), + ] mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df)) df = mod["main"] @@ -118,7 +124,12 @@ def test_checkpoint_alpha_equal_tuple(): # run PE and DCE with tvm.transform.PassContext(opt_level=3): - passes = [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)] + # See comment in test_checkpoint_alpha_equal above. + # TODO(mbs): Revisit once DCE supports dead reference writes. 
+ passes = [ + transform.PartialEvaluate(), + transform.DeadCodeElimination(inline_once=True, ignore_impurity=True), + ] mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df)) df = mod["main"] @@ -612,4 +623,7 @@ def _verify(prediction_shape, reduction="mean", ignore_index=-100, dtype="float3 if __name__ == "__main__": - pytest.main([__file__]) + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/test_pass_dead_code_elimination.py b/tests/python/relay/test_pass_dead_code_elimination.py index 127035c5d540..3893da45fcaa 100644 --- a/tests/python/relay/test_pass_dead_code_elimination.py +++ b/tests/python/relay/test_pass_dead_code_elimination.py @@ -15,31 +15,30 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import te -from tvm import relay from tvm.relay import Function, transform -from tvm.relay.analysis import free_vars -from tvm.relay.op import log, add, equal, subtract from tvm.relay.testing import inception_v3 - import pytest +cpu_scope = tvm.target.make_se_scope(tvm.cpu(), tvm.target.Target("llvm")) +metatable = {"SEScope": [cpu_scope]} +core = tvm.IRModule() +core.import_from_std("core.rly") + -def optimize_source(source, passes): +def optimize_and_check(before_program, after_program, passes): + if isinstance(before_program, str): + before_program = tvm.parser.parse(before_program) + if isinstance(after_program, str): + after_program = tvm.parser.parse(after_program) if not isinstance(passes, list): passes = [passes] - optimize = tvm.transform.Sequential(passes) - module = tvm.parser.parse(source) - return optimize(module) - - -def optimize_and_check(before_source, after_source, passes): - optimize_module = optimize_source(before_source, passes) - after_module = tvm.parser.parse(after_source) - print(optimize_module) - print(after_module) - assert tvm.ir.structural_equal(after_module, optimize_module) + optimized_program = 
optimize(before_program) + print("Actual:") + print(optimized_program) + print("Expected:") + print(after_program) + assert tvm.ir.structural_equal(optimized_program, after_program, map_free_vars=True) def test_dead_let(): @@ -197,7 +196,153 @@ def @main() { optimize_and_check(before_program, after_program, transform.DeadCodeElimination()) +def test_inline_into_function(): + """Don't inline across function boundaries.""" + before_program = """ + #[version = "0.0.5"] + def @main() { + let %x = 1 + 1; + let %f = fn (%y: int) -> int { + let %z = %y + %y; + %x + %z + }; + (%f(2), %f(3)) + } + """ + + after_program = """ + #[version = "0.0.5"] + def @main() { + let %x = 1 + 1; + let %f = fn (%y: int) -> int { + %x + (%y + %y) + }; + (%f(2), %f(3)) + } + """ + + optimize_and_check( + before_program, after_program, transform.DeadCodeElimination(inline_once=True) + ) + + +def test_impure_op(): + """Don't elide calls to side-effecting operators.""" + before_program = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main() { + let %size: int64 = cast(1024, dtype="int64"); + let %alignment: int64 = cast(64, dtype="int64"); + let %x = memory.alloc_storage(%size, %alignment, se_scope=meta[SEScope][0]); + 0 + } + """, + "from_string", + core, + metatable, + ) + + after_program = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main() { + let %x = memory.alloc_storage(cast(1024, dtype="int64"), + cast(64, dtype="int64"), + se_scope=meta[SEScope][0]); + 0 + } + """, + "from_string", + core, + metatable, + ) + + optimize_and_check( + before_program, after_program, transform.DeadCodeElimination(inline_once=True) + ) + + +def test_impure_func(): + """Don't elide calls to side-effecting functions.""" + before_program = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @f() -> int { + let %size: int64 = cast(1024, dtype="int64"); + let %alignment: int64 = cast(64, dtype="int64"); + let %x = memory.alloc_storage(%size, %alignment, se_scope=meta[SEScope][0]); + 0 + } + def 
@main() -> int { + let %y = @f(); + 0 + } + """, + "from_string", + core, + metatable, + ) + + after_program = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @f() -> int { + let %x = memory.alloc_storage(cast(1024, dtype="int64"), + cast(64, dtype="int64"), + se_scope=meta[SEScope][0]); + 0 + } + def @main() -> int { + let %y = @f(); + 0 + } + """, + "from_string", + core, + metatable, + ) + + optimize_and_check( + before_program, after_program, transform.DeadCodeElimination(inline_once=True) + ) + + +def test_refs(): + """Don't elide expressions with reference create/read/write side effects""" + before_program = """ + #[version = "0.0.5"] + def @f(%r) -> int { + let %v = ref_read(%r); + let %u = ref_write(%r, %v + 1); + %v + } + def @main() -> int { + let %r = ref(0); + let %y = @f(%r); + let %z = @f(%r); + %z + } + """ + + after_program = before_program + + optimize_and_check( + before_program, + after_program, + [transform.InferType(), transform.DeadCodeElimination(inline_once=True)], + ) + + +def test_complexity(): + mod = transform.InferType()( + tvm.IRModule.from_expr(inception_v3.get_net(1, 1000, (3, 299, 299), "float32")) + ) + + optimize_and_check(mod, mod, transform.DeadCodeElimination(inline_once=True)) + + if __name__ == "__main__": import sys - pytest.main(sys.argv) + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/test_pass_lazy_gradient_init.py b/tests/python/relay/test_pass_lazy_gradient_init.py index a0af2205a5d0..bc18f0a212af 100644 --- a/tests/python/relay/test_pass_lazy_gradient_init.py +++ b/tests/python/relay/test_pass_lazy_gradient_init.py @@ -22,7 +22,6 @@ from tvm.relay.testing import rand, run_infer_type import tvm.testing from tvm.testing import assert_allclose -import pytest def test_tc(): @@ -288,7 +287,9 @@ def test_after_partial_eval(): transform.PartialEvaluate(), transform.InferType(), transform.LazyGradientInit(), + transform.InferType(), transform.DeadCodeElimination(), + 
transform.InferType(), ] ) @@ -326,7 +327,13 @@ def test_before_partial_eval(): mod["main"] = back_func seq = tvm.transform.Sequential( - [transform.LazyGradientInit(), transform.PartialEvaluate(), transform.DeadCodeElimination()] + [ + transform.LazyGradientInit(), + transform.PartialEvaluate(), + transform.InferType(), + transform.DeadCodeElimination(), + transform.InferType(), + ] ) mod = seq(mod) back_func = mod["main"] @@ -438,4 +445,7 @@ def test_ones_like(): if __name__ == "__main__": - pytest.main([__file__]) + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py index ce36abd83c40..84ecc8477e50 100644 --- a/tests/python/relay/test_pass_partial_eval.py +++ b/tests/python/relay/test_pass_partial_eval.py @@ -15,10 +15,8 @@ # specific language governing permissions and limitations # under the License. -import pytest import numpy as np import tvm -from tvm import te from tvm import relay from tvm.relay.prelude import Prelude from tvm.relay import op, create_executor, transform @@ -49,11 +47,11 @@ def tipe(expr): return run_opt_pass(expr, [transform.PartialEvaluate(), transform.InferType()]) -def dcpe(expr, mod=None, grad=False): +def dcpe(expr, mod=None, grad=False, ignore_impurity=False): passes = [ transform.PartialEvaluate(), transform.InferType(), - transform.DeadCodeElimination(inline_once=True), + transform.DeadCodeElimination(inline_once=True, ignore_impurity=ignore_impurity), transform.InferType(), ] if grad: @@ -95,7 +93,9 @@ def test_ref(): body = Let(r, RefCreate(d), body) square = Function([d], body) expected = run_opt_pass(Function([d], d * d), transform.InferType()) - assert tvm.ir.structural_equal(dcpe(square), expected) + # TODO(mbs): Revisit once DCE eliminates dead writes. 
+ actual = dcpe(square, ignore_impurity=True) + assert tvm.ir.structural_equal(actual, expected) def test_empty_ad(): @@ -104,7 +104,8 @@ def test_empty_ad(): t = TensorType(shape, dtype) d = Var("d", t) f = Function([d], d) - g = dcpe(f, grad=True) + # TODO(mbs): Revisit once DCE eliminates dead writes. + g = dcpe(f, grad=True, ignore_impurity=True) expected = Function([d], Tuple([d, Tuple([op.ones_like(d)])])) expected = run_opt_pass(expected, transform.InferType()) assert tvm.ir.structural_equal(g, expected) @@ -116,7 +117,8 @@ def test_ad(): t = TensorType(shape, dtype) d = Var("d", t) f = Function([d], d * d) - g = dcpe(f, grad=True) + # TODO(mbs): Revisit once DCE eliminates dead writes. + g = dcpe(f, grad=True, ignore_impurity=True) m = d * d x = relay.Var("x") o = op.ones_like(x) @@ -348,4 +350,7 @@ def test_tuple_match(): if __name__ == "__main__": - pytest.main([__file__]) + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/test_pass_to_cps.py b/tests/python/relay/test_pass_to_cps.py index 4825cc29e6e4..2200320c448b 100644 --- a/tests/python/relay/test_pass_to_cps.py +++ b/tests/python/relay/test_pass_to_cps.py @@ -73,10 +73,15 @@ def destroy_ref(x): x = run_infer_type(x) y = un_cps(x) y = run_infer_type(y) + # TODO(mbs): Revisit once DCE can eliminate dead writes. x = run_opt_pass( x, tvm.transform.Sequential( - [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)] + [ + transform.PartialEvaluate(), + transform.InferType(), + transform.DeadCodeElimination(inline_once=True, ignore_impurity=True), + ] ), ) assert Feature.fRefCreate not in detect_feature(x) @@ -118,5 +123,7 @@ def destroy_ref(x): if __name__ == "__main__": - test_recursion() - test_cps_pe() + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:]))