From 59dfc3876534d4f866461e73c47079858073d4b6 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 15 Feb 2023 21:35:14 +0800 Subject: [PATCH] [Unity][Pass] Operator Fusion Passes (#14001) [Unity][Pass] Operator fusion passes This PR introduces three passes for operator fusion: 1. AnnotateTIROpPattern: analyzes the operator kind of each PrimFunc. 2. FuseOps: fuses operators in Relax functions, adding a new fused primitive Relax function for each group. 3. FuseTIR: fuses the corresponding TIR PrimFuncs for the fused Relax functions. --- include/tvm/relax/analysis.h | 11 + include/tvm/tir/buffer.h | 14 +- python/tvm/relax/transform/transform.py | 43 + .../transform/annotate_tir_op_pattern.cc | 55 ++ src/relax/transform/fuse_ops.cc | 909 ++++++++++++++++++ src/relax/transform/fuse_tir.cc | 728 ++++++++++++++ .../test_transform_annotate_tir_op_pattern.py | 360 +++++++ tests/python/relax/test_transform_fuse_ops.py | 759 +++++++++++++++ tests/python/relax/test_transform_fuse_tir.py | 563 +++++++++++ tests/python/relax/test_tvmscript_parser.py | 1 - 10 files changed, 3441 insertions(+), 2 deletions(-) create mode 100644 src/relax/transform/annotate_tir_op_pattern.cc create mode 100644 src/relax/transform/fuse_ops.cc create mode 100644 src/relax/transform/fuse_tir.cc create mode 100644 tests/python/relax/test_transform_annotate_tir_op_pattern.py create mode 100644 tests/python/relax/test_transform_fuse_ops.py create mode 100644 tests/python/relax/test_transform_fuse_tir.py diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 24cfe5b9bf11f..a55fe6797d45b 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -260,6 +260,17 @@ TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived, TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs, arith::Analyzer* ana = nullptr); +/*! + * \brief Annotate Op Pattern Kind for PrimFunc, which is used in relax FuseOps. + * + * \param func The PrimFunc to be analyzed. + * \return The Op Pattern Kind. + * + * \note This analysis applies to TIR functions but is primarily used by relax passes. + * As a result, we place it under the relax namespace. + */ +TVM_DLL relay::OpPatternKind AnalyzeOpPatternKind(const tir::PrimFunc& func); + /*! * \brief Check if the given PrimFunc is essentially doing a reshape operation. * The reshape operation also includes expand_dims, squeeze, flatten, etc. diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index d7a2aec0b9725..e3a853e4c7ea2 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -34,6 +34,18 @@ namespace tvm { namespace tir { +#ifndef TVM_INDEX_DEFAULT_I64 +#define TVM_INDEX_DEFAULT_I64 1 +#endif +/*! \brief If TVM_INDEX_DEFAULT_I64 is set, return int64; otherwise return int32 */ +inline DataType DefaultIndexType() { +#if TVM_INDEX_DEFAULT_I64 + return DataType::Int(64); +#else + return DataType::Int(32); +#endif +} + // forward declare Stmt class Stmt; @@ -135,7 +147,7 @@ class BufferNode : public Object { /*! \return preferred index type for this buffer node */ DataType DefaultIndexType() const { - return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32); + return shape.size() != 0 ? shape[0].dtype() : tvm::tir::DefaultIndexType(); } /*! \brief Determine the offset in the buffer of the given index. 
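As a usage sketch (illustrative, not part of the patch; `Module` stands for any IRModule containing Relax functions that call TIR PrimFuncs), the three passes introduced here are expected to compose in order:

import tvm
from tvm import relax

# Hypothetical pipeline: annotate patterns, group Relax bindings, then fuse TIR.
seq = tvm.transform.Sequential(
    [
        relax.transform.AnnotateTIROpPattern(),  # 1. attach "op_pattern" to each PrimFunc
        relax.transform.FuseOps(),  # 2. group bindings into fused primitive Relax functions
        relax.transform.FuseTIR(),  # 3. fuse the underlying TIR PrimFuncs
    ]
)
fused_mod = seq(Module)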
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index cab18797c6728..0f973db290f8f 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -105,6 +105,49 @@ def AttachGlobalSymbol() -> tvm.ir.transform.Pass: return _ffi_api.AttachGlobalSymbol() # type: ignore +def AnnotateTIROpPattern() -> tvm.ir.transform.Pass: + """Annotate Op Pattern Kind for TIR functions + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.AnnotateTIROpPattern() # type: ignore + + +def FuseOps(fuse_opt_level=-1) -> tvm.ir.transform.Pass: + """This pass groups bindings in a dataflow block of Relax functions and generates a new grouped + Relax function for each group, according to the fusion algorithm described in the pass + implementation. By grouping bindings into new Relax functions, we substitute the bindings in + the function being manipulated with calls to the new grouped functions. + + A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each grouped function. + + Parameters + ---------- + fuse_opt_level : int + The level of fusion optimization. -1 indicates that the level will be + inferred from the pass context. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass for operator fusion. + """ + return _ffi_api.FuseOps(fuse_opt_level) # type: ignore + + +def FuseTIR() -> tvm.ir.transform.Pass: + """Fuse primitive relax functions into larger TIR functions if possible + + Returns + ------- + ret : tvm.transform.Pass + The registered pass for TIR fusion. + """ + return _ffi_api.FuseTIR() # type: ignore + + def _wrap_class_function_pass(pass_cls, pass_info): """Wrap a python class as function pass.""" diff --git a/src/relax/transform/annotate_tir_op_pattern.cc b/src/relax/transform/annotate_tir_op_pattern.cc new file mode 100644 index 0000000000000..b1c1ed29aff39 --- /dev/null +++ b/src/relax/transform/annotate_tir_op_pattern.cc @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/transform/annotate_tir_op_pattern.cc + * \brief Annotate Op Pattern for TIR functions. It is a pass that works on TIR PrimFuncs, + * but it is needed for relax fusion, so we put it in the relax namespace. 
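+ * + * For example (illustrative), after this pass an element-wise PrimFunc carries an integer + * attribute such as `"op_pattern" = 0`, i.e. relay::OpPatternKind::kElemWise.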
+ */ +#include +#include +#include + +namespace tvm { +namespace relax { + +tir::PrimFunc AnnotateOpPattern(tir::PrimFunc f) { + if (f->HasNonzeroAttr("op_pattern")) { + return f; + } else { + relay::OpPatternKind kind = AnalyzeOpPatternKind(f); + return WithAttr(std::move(f), "op_pattern", Integer(static_cast(kind))); + } +} + +namespace transform { + +Pass AnnotateTIROpPattern() { + auto pass_func = [=](tir::PrimFunc f, IRModule m, PassContext ctx) { + return AnnotateOpPattern(std::move(f)); + }; + return tir::transform::CreatePrimFuncPass(pass_func, 0, "AnnotateTIROpPattern", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.AnnotateTIROpPattern").set_body_typed(AnnotateTIROpPattern); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc new file mode 100644 index 0000000000000..f3559b72da3f9 --- /dev/null +++ b/src/relax/transform/fuse_ops.cc @@ -0,0 +1,909 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/transform/fuse_ops.cc + * \brief This file contains a pass which groups bindings in a dataflow block of Relax + * functions and generates a new grouped Relax function for each group, according to the fusion + * algorithm described below. By grouping bindings into new Relax functions, we substitute the + * bindings in the function being manipulated with calls to the new grouped functions. + * + * A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each grouped function. + */ + +#include +#include +#include +#include +#include + +#include + +#include "../../relay/analysis/graph_partitioner.h" +#include "../../support/arena.h" + +namespace tvm { +namespace relax { + +/* + Note on Fusing algorithm: + + The main challenge of a general fusor is to handle possible diamond-shaped branches. + In the following graph, conv2d can be fused into the elemwise add. + + conv2d + / | \ + / | \ + op op op + \ | / + \ | / + elemwise add + | + + However, at the point of conv2d we do not necessarily know that all the future paths + will merge at the elemwise add. The fusion algorithm applies post-dominator analysis. + + The immediate post-dominator of a node is defined as the closest node through which all + future paths pass. In the above case, the elemwise add is the post-dominator of conv2d. The general algorithm + is as follows: + + - Construct a DAG from the dataflow graph for dominator analysis + - Construct a post-dominator tree which gives the immediate post-dominator of each node. + - Run the fusion algorithm with the given post-dominator information. 
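+ + (For reference, the op pattern kinds reused from relay are, roughly from most to least + fusible: kElemWise, kBroadcast, kInjective, kCommReduce, kOutEWiseFusable and kOpaque; + kOpaque nodes are never fused into others.)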
+ + Note that, because we run analysis on a DAG, we use a single-pass post-dominator + tree construction algorithm via LCA, which is simpler than the full version that handles cycles. + + The fusion algorithm traverses from each node and checks if it can be fused to its + immediate post-dominator. It has to check the following things: + + - CheckPath: check that every path between a node and its immediate post-dominator + satisfies the fuse condition. + - Note that these intermediate nodes may already be fused with other nodes; the algorithm + will still run correctly. + - CommitFuse: mark all the nodes between source and post-dominator as the same group. + - We use a Union-Find data structure to manage the groups. +*/ + +using relay::GraphPartitioner; +using relay::IndexedForwardGraph; +using relay::OpPatternKind; +using support::LinkNode; + +constexpr uint32_t kMaxFusedOps = 256; + +TVM_REGISTER_PASS_CONFIG_OPTION("relax.FuseOps.max_depth", Integer); + +class GraphCreator : public ExprVisitor { + public: + /*! + * \brief Create an IndexedForwardGraph according to the input module. The graph will be used for + * graph partitioning and operator fusion. + * \param mod The module from which the graph is created + * \param arena The allocator of all the internal node objects + * \return The created IndexedForwardGraph + */ + static IndexedForwardGraph Create(IRModule mod, support::Arena* arena) { + // Since cross-function calls are not supported yet, FuseOps only serves the entry function, whose + // name is "main". + auto relax_func = Downcast(mod->Lookup("main")); + GraphCreator creator(mod, arena); + creator(relax_func); + + // The algorithm of the graph creator ensures that each created node will be added to the + // post-dfs order and will have its op pattern set. Thus we check whether all these containers + // have the same size. + size_t n_nodes = creator.graph_.node_map.size(); + ICHECK_EQ(n_nodes, creator.graph_.post_dfs_order.size()); + ICHECK_EQ(n_nodes, creator.initialized_nodes_.size()); + + return creator.graph_; + } + + private: + explicit GraphCreator(IRModule mod, support::Arena* arena) + : mod_(std::move(mod)), arena_(arena) {} + + void VisitExpr_(const FunctionNode* func) final { + for (const Var& param : func->params) { + IndexedForwardGraph::Node* param_node = CreateNode(param.get()); + // The parameter is passed in from the outside, and thus it is marked as an external reference, + // and its pattern is `kOpaque`. + MarkAsExternRef(param_node); + SetNodePattern(param_node, OpPatternKind::kOpaque); + AddToPostDFSOrder(param_node, param.get()); + } + ExprVisitor::VisitExpr_(func); + } + + void VisitBindingBlock(const BindingBlock& block) final { + if (const auto* df_block = block.as()) { + VisitBindingBlock_(df_block); + } + // We skip ordinary binding blocks since they might be impure (with side effects or control flow) + } + + // TODO(tvm-team): how to deal with MatchCast binding here + + void VisitBinding_(const VarBindingNode* binding) final { + IndexedForwardGraph::Node* node = CreateNode(binding->var.get()); + + // If the variable is not a dataflow variable, it must be the output variable of this dataflow + // block + if (!binding->var->IsInstance()) { + this->MarkAsExternRef(node); + } + if (const auto* call = binding->value.as()) { + // Case 1. The expression is a CallNode + VisitCall(call, node); + } else if (const auto* tuple_get_item = binding->value.as()) { + // Case 2. 
The expression is a TupleGetItemNode + VisitTupleGetItem(tuple_get_item, node); + } else { + VisitUnsupportedNode(binding->value, node); + // Case 3. The type of the expression is not fusion-supported. + // In this case, we skip adding edges, adding an empty node into graph. + } + AddToPostDFSOrder(node, binding->var.get()); + } + + /********** Non-Leaf Expression Nodes **********/ + + void VisitCall(const CallNode* call, IndexedForwardGraph::Node* binding_var_node) { + ICHECK_NOTNULL(binding_var_node); + + static const Op& call_tir_op_ = Op::Get("relax.call_tir"); + OpPatternKind pattern = OpPatternKind::kOpaque; + Array args = call->args; + + // - If the op being called is a TIR PrimFunc, we get the function op pattern directly from the + // function attribute and visit the arguments one by one. + // - Otherwise, the pattern of the current binding variable node is set to `kOpaque`, and we + // recurse into the call expression. + const auto* op = call->op.as(); + if (op == call_tir_op_.get()) { + const GlobalVar& global_var = Downcast(call->args[0]); + tir::PrimFunc func = Downcast(mod_->Lookup(global_var)); + + // Override args for call_tir + args = Downcast(call->args[1])->fields; + + // TODO(tvm-team): handle the shape argument (args[3]) + Optional opt_pattern = func->GetAttr("op_pattern"); + if (opt_pattern.defined()) { + pattern = static_cast(Downcast(opt_pattern)->value); + } else { + pattern = OpPatternKind::kOpaque; + } + } + // The pattern of the current binding variable node is set to the pattern of this operator. + SetNodePattern(binding_var_node, pattern); + // Visit all call args + for (const Expr& arg : args) { + ICHECK(IsLeaf(arg)); + VisitLeaf(arg, binding_var_node, pattern); + } + } + + void VisitTupleGetItem(const TupleGetItemNode* tuple_item, + IndexedForwardGraph::Node* binding_var_node) { + ICHECK_NOTNULL(binding_var_node); + + SetNodePattern(binding_var_node, OpPatternKind::kInjective); + VisitLeaf(tuple_item->tuple, binding_var_node, OpPatternKind::kInjective); + } + + void VisitUnsupportedNode(const Expr& expr, IndexedForwardGraph::Node* binding_var_node) { + ICHECK_NOTNULL(binding_var_node); + SetNodePattern(binding_var_node, OpPatternKind::kOpaque); + + auto visit_leaves = [this, &binding_var_node](const Expr& e) { + if (e->IsInstance() || e->IsInstance()) { + VisitLeaf(e, binding_var_node, OpPatternKind::kOpaque); + } + }; + PostOrderVisit(expr, visit_leaves); + } + + /********** Leaf Expression Nodes **********/ + + void VisitLeaf(const Expr& leaf_expr, IndexedForwardGraph::Node* binding_var_node, + const OpPatternKind& pattern) { + ICHECK_NOTNULL(binding_var_node); + + // Recursive visit if it's Tuple + if (const auto* tuple = leaf_expr.as()) { + for (const Expr& expr : tuple->fields) { + VisitLeaf(expr, binding_var_node, pattern); + } + return; + } + + auto it = graph_.node_map.find(leaf_expr.get()); + IndexedForwardGraph::Node* leaf_node = nullptr; + if (it != graph_.node_map.end()) { + leaf_node = it->second; + } else if (leaf_expr->IsInstance()) { + leaf_node = CreateNode(leaf_expr.get()); + // Since we never fuse constants, the pattern of the constant is set to `kOpaque`. + SetNodePattern(leaf_node, OpPatternKind::kOpaque); + AddToPostDFSOrder(leaf_node, leaf_expr.get()); + } else { + LOG(FATAL) << "The leaf Expr is supposed to be defined before, but got: " << leaf_expr + << " used before definition."; + } + AddEdge(leaf_node, binding_var_node, pattern); + } + + /********** Helper Functions **********/ + + /*! 
+ * \brief Check whether the expression is a leaf expression + * \param expr The expression to be checked + * \return Whether the expression is a leaf expression + * \note In order to avoid too much refactor, this method is a simple copy-paste of the is-leaf + * check in "block_builder.cc". And it should be refactored in the future. + * \sa src/relax/ir/block_builder.cc + */ + static bool IsLeaf(const Expr& expr) { + // NOTE: Tuples are treated as leaf nodes for ergonomics + return expr.as() || expr.as() || expr.as() || + expr.as() || expr.as() || expr.as() || + expr.as(); + } + + /*! + * \brief Create a graph node corresponding to the input key + * \param key The object which is used to create the graph node + * \return The created graph node + * \note The node corresponding to each key is supposed to be created for only once + */ + IndexedForwardGraph::Node* CreateNode(const Object* key) { + ICHECK(graph_.node_map.find(key) == graph_.node_map.end()) + << "The node corresponding to the input key is not supposed to be created before"; + auto* node = arena_->make(); + graph_.node_map[key] = node; + return node; + } + + /*! + * \brief Append the input node to the post-dfs order of the graph + * \param node The node to be appended + * \param key The key corresponding to the node + * \note Each node is supposed to be appended to the post-dfs order for only once + */ + void AddToPostDFSOrder(IndexedForwardGraph::Node* node, const Object* key) { + auto it = graph_.node_map.find(key); + ICHECK(it != graph_.node_map.end() && it->second == node) + << "The node must have been created before adding to the post-dfs order"; + + // We only set the reference of the node when adding it to the post-dfs order. Thus, if the + // reference of a node is already set, it must have been appended to the post-dfs order. + ICHECK(node->ref == nullptr) + << "The node is not supposed to be added into the post-dfs order before"; + + node->ref = key; + node->index = graph_.post_dfs_order.size(); + graph_.post_dfs_order.push_back(node); + } + + /*! + * \brief Add an edge from the input start to the input end in the graph, with specific pattern + * \param start The start of the edge + * \param end The end of the edge + * \param pattern The pattern of this edge + */ + void AddEdge(IndexedForwardGraph::Node* start, IndexedForwardGraph::Node* end, + OpPatternKind pattern) { + auto* link = arena_->make>(); + link->value.node = end; + link->value.pattern = pattern; + start->outputs.Push(link); + } + + /*! + * \brief Mark a given node as "external reference", which means the node cannot be fused as an + * intermediate node + * \param node The graph node to be marked + */ + void MarkAsExternRef(IndexedForwardGraph::Node* node) { node->extern_ref = true; } + + /*! + * \brief Set the pattern of the input node + * \param node The graph node to be set + * \param pattern The pattern of the node + */ + void SetNodePattern(IndexedForwardGraph::Node* node, OpPatternKind pattern) { + ICHECK(initialized_nodes_.find(node) == initialized_nodes_.end()) + << "The input node is supposed to be set pattern for only once"; + initialized_nodes_.insert(node); + node->pattern = pattern; + } + + private: + /*! \brief The IRModule from which the indexed forward graph is created */ + IRModule mod_; + /*! \brief The allocator of all the internal node objects */ + support::Arena* arena_; + /*! \brief The created indexed forward graph */ + IndexedForwardGraph graph_; + /*! 
\brief The graph nodes whose patterns are set */ + std::unordered_set initialized_nodes_; +}; + +/*! + * \brief The ExprMutator used to create a new grouped function + * \details The workflow of this ExprMutator is: + * - The bindings in the function will be added by OperatorFusor via `AppendBinding(...)`. + * - When adding a new binding through `AppendBinding(...)`, we check whether the variables and + * constants used by the binding are defined by some previously added binding. For the undefined + * variables and constants, we add them to the argument list and create new variables as the + * corresponding parameters. + * - When `CreateFunction()` is called, we go through each binding and update the binding with the + * new parameters. After that we wrap all bindings with a DataflowBlock and a Function. + */ +class FunctionCreator : public ExprMutator { + public: + explicit FunctionCreator(bool lift_constant) : lift_constant_(lift_constant) {} + /*! + * \brief Append a new binding to this function and possibly create new parameters for the + * function accordingly + * \param binding The binding to be appended + * \note Allowed bindings are: + * - VarBinding with value being a call node calling `relax.call_tir`. + * - VarBinding with value being a tuple-get-item node. + * // TODO(tvm-team): handle match shape + */ + void AppendBinding(const Binding& binding) { + ICHECK(!function_.defined()) + << "The `function_` is not supposed to have been created yet when adding bindings"; + + if (const auto* var_binding = binding.as()) { + if (const auto* call = var_binding->value.as()) { + if (call->op == Op::Get("relax.call_tir")) { + // Update the name of the function. + name_hint_ = name_hint_ + "_" + Downcast(call->args[0])->name_hint; + + const Tuple& args = Downcast(call->args[1]); + for (const Expr& arg : args->fields) { + CheckDefAndUpdateParam(arg); + } + // TODO(tvm-team): handle shape expr + } else { + if (call->op->IsInstance()) { + name_hint_ = name_hint_ + "_" + Downcast(call->op)->name; + } else if (call->op->IsInstance()) { + std::string gvar_name = Downcast(call->op)->name_hint; + if (auto pos = gvar_name.find("fused_"); pos == 0) { + name_hint_ = name_hint_ + "_" + gvar_name.substr(std::string("fused_").size()); + } else { + name_hint_ = name_hint_ + "_" + gvar_name; + } + } + + for (const Expr& arg : call->args) { + CheckDefAndUpdateParam(arg); + } + } + } else { + const auto* tuple_item = var_binding->value.as(); + ICHECK(tuple_item != nullptr); + CheckDefAndUpdateParam(tuple_item->tuple); + } + + // Mark the binding variable as defined. + defined_vars_.insert(var_binding->var.get()); + // Mark the var as an output if the binding is not a dataflow variable + if (!var_binding->var->IsInstance()) { + AppendOutput(var_binding->var); + } + } else { + // TODO(tvm-team): handle match_cast + } + bindings_.push_back(binding); + } + + /*! \brief Set a var defined in the group as output. */ + size_t AppendOutput(const Var& var) { + ICHECK(defined_vars_.count(var.get())); + auto output_idx = GetOutputIndex(var); + if (output_idx) { + return *output_idx; + } + output_vars_.push_back(var.get()); + return output_vars_.size() - 1; + } + + /*! + * \brief Create the grouped function according to the collected bindings and parameters + * \param group_attrs The group attributes to attach to the created function. If the function is + * created from a matched pattern, they include the kComposite attribute identifying that pattern. + * \note The created function won't be returned immediately. 
It's stored in the `function_` field. + */ + void CreateFunction(Map group_attrs) { + // Step 1. Start constructing a new dataflow block. + builder_->BeginDataflowBlock(); + + // Step 2. Visit each binding and collect outputs one by one. + Array outputs(output_vars_.size(), Expr()); + for (const Binding& binding : bindings_) { + if (auto output_idx = GetOutputIndex(binding->var)) { + // Case 1. It is an output binding + // We only allow VarBinding as output. + const auto* var_binding = binding.as(); + ICHECK_NOTNULL(var_binding); + Var output_var = builder_->EmitOutput(VisitExpr(var_binding->value)); + var_remap_[var_binding->var->vid] = output_var; + outputs.Set(*output_idx, output_var); + } else { + // Case 2. It is an internal binding, add it to the binding list. + VisitBinding(binding); + } + } + + // Step 3. Finish constructing the new block. + BindingBlock new_block = builder_->EndBlock(); + ICHECK(!outputs.empty()) << "At least one output is required."; + Expr body = outputs.size() == 1 ? outputs[0] : Tuple(outputs); + body = builder_->Normalize(body); + body = builder_->Normalize(SeqExpr({new_block}, body)); + group_attrs.Set(tvm::relax::attr::kPrimitive, Integer(1)); + function_ = Function(/*params=*/params_, // + /*body=*/body, // + /*ret_struct_info=*/NullOpt, // + /*attrs=*/DictAttrs(group_attrs)); + } + + /*! \brief The original bindings of the function */ + Array bindings_; + /*! \brief The parameters of the function */ + Array params_; + /*! \brief The arguments to call the function on the caller side */ + Array arguments_; + /*! \brief The name for the fused function */ + String name_hint_ = "fused"; + /*! \brief The constructed Relax function */ + Function function_{nullptr}; + + private: + std::optional GetOutputIndex(Var v) { + auto it = std::find(output_vars_.begin(), output_vars_.end(), v.get()); + if (it != output_vars_.end()) { + return std::distance(output_vars_.begin(), it); + } + return std::nullopt; + } + + /*! + * \brief Check whether the input expression is defined within this function. If not, create a new + * parameter for the expression. + * \param expr The expression to be checked + */ + void CheckDefAndUpdateParam(const Expr& expr) { + // If the expression has already served as an argument, no need to create another one for it. + if (std::find(arguments_.begin(), arguments_.end(), expr) != arguments_.end()) { + return; + } + + // If the expression is not a variable or is an undefined variable, it should be populated as a + // parameter of the relax function. + const auto* var = expr.as(); + if ((var == nullptr || defined_vars_.count(var) == 0) && + (lift_constant_ || !expr->IsInstance())) { + String name{nullptr}; + if (var != nullptr) { + name = var->name_hint(); + } else { + name = String("param_" + std::to_string(n_param_for_const_++)); + } + + Var param(std::move(name), GetStructInfo(expr)); + arguments_.push_back(expr); + params_.push_back(param); + } + } + + Expr VisitExpr(const Expr& expr) final { + // If the expression serves as an argument, return its corresponding parameter. + auto it = std::find(arguments_.begin(), arguments_.end(), expr); + if (it != arguments_.end()) { + return params_[it - arguments_.begin()]; + } + // Otherwise, recurse into this expression. + return ExprMutator::VisitExpr(expr); + } + + private: + /*! \brief The variables defined in this function */ + std::unordered_set defined_vars_; + /*! \brief The number of parameters reserved for constants */ + int n_param_for_const_ = 0; + /*! 
\brief The output vars */ + std::vector output_vars_; + /*! \brief Whether or not to lift bound constants to parameters */ + bool lift_constant_; +}; + +/*! + * \brief The ExprMutator used to fuse the operators in Relax functions + * \details Given the partition results on the indexed-forward graph, for each group whose size is + * larger than one, we create a new grouped function for it, containing all bindings in that group. + * And we substitute the bindings in a group with a single function call to the newly created + * grouped function. The workflow of this ExprMutator is: for each dataflow block, + * - we go through the bindings one by one. For each binding, if it is in a group whose size is + * larger than one, we add the binding to the function of the group it is in and update the + * parameters and arguments of that function; + * - then we finalize all the grouped functions by updating their bindings using BlockBuilder; + * - lastly, we go through the bindings again and substitute the bindings in a group with a single + * call to the corresponding grouped function. + * + * After transforming a Relax function, we update the function in the IRModule. Besides, we add all + * newly created grouped function to the IRModule. + */ +class OperatorFusor : public ExprMutator { + public: + using Group = GraphPartitioner::Group; + using GroupMap = std::unordered_map; + + OperatorFusor(IRModule mod, const GroupMap& obj2group, bool lift_constants = true) + : ExprMutator(mod), + mod_(std::move(mod)), + obj2group_(obj2group), + lift_constants_(lift_constants) {} + + /*! + * \brief Construct a new operator fusor. Given the indexed-forward graph and the graph partition + * result on that graph, the constructor creates a mapping from each leaf AST object + * (e.g. parameters, variables, constants) to the group of the node corresponding to the object + * in the graph. + * \param mod The IRModule to be transformed + * \param graph The indexed-forward graph of the input IRModule + * \param groups The grouped result of the group partition on the input indexed-forward graph. + */ + OperatorFusor(IRModule mod, const IndexedForwardGraph& graph, const std::vector& groups, + bool lift_constant = true) + : OperatorFusor(mod, CreateGroupMap(graph, groups), lift_constant) {} + + /*! + * \brief The main transformation on the IRModule + * \return The new IRModule after transformation + */ + IRModule Transform() { + for (const auto& [gv, func] : mod_->functions) { + // Only visit Relax function without attr kPrimitive. 
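+ // (Functions already marked with kPrimitive are the fused groups themselves; + // they are consumed later by FuseTIR, so we leave them untouched here.)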
+ if (func->IsInstance() && !func->HasNonzeroAttr(attr::kPrimitive)) { + auto updated_func = Downcast(VisitExpr(func)); + builder_->UpdateFunction(gv, updated_func); + } + } + return builder_->GetContextIRModule(); + } + + private: + static GroupMap CreateGroupMap(const IndexedForwardGraph& graph, + const std::vector& groups) { + GroupMap obj2group; + for (int nid = 0; nid < static_cast(graph.post_dfs_order.size()); ++nid) { + Group* group_root = groups[nid]->FindRoot(); + ICHECK(group_root != nullptr); + ICHECK(graph.post_dfs_order[nid]->ref != nullptr); + obj2group[graph.post_dfs_order[nid]->ref] = group_root; + } + return obj2group; + } + + bool IsTupleOutput(Function f) { + auto sinfo = GetStructInfo(f).as(); + ICHECK(sinfo); + return sinfo->ret->IsInstance(); + } + + BindingBlock VisitBindingBlock(const BindingBlock& block) final { + if (const auto* df_block = block.as()) { + return VisitBindingBlock_(df_block); + } + // We skip ordinary binding blocks since they might be impure (with side effects or control flow) + return block; + } + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final { + group2func_.clear(); + + // Step 1. Collect the bindings for each grouped function. + CollectFuncBindings(block->bindings); + + // Step 2. Collect each group's boundary (i.e. the output vars of each group) + CollectFuncBoundary(block->bindings); + + // Step 3. Create the grouped function for each group. + for (auto& [g, creator] : group2func_) { + creator.CreateFunction(g->attrs); + } + + // Step 4. Start generating the new binding block. + // - For groups with a single binding, we directly recurse into the binding and emit the new one. + // - For groups with multiple bindings, we emit the call to the grouped function only when + // visiting the last binding of the group, because only by doing this do we avoid breaking the + // dependencies among the bindings of different groups. Therefore, we will skip all but the + // last binding of the group. + builder_->BeginDataflowBlock(); + + // For each group, record which variables need to be remapped to the output of TupleGetItem. + // Only relevant when the output of the grouped function is a tuple. + std::unordered_map> pending_tuple_get; + + // A grouped function which returns a tuple requires attaching TupleGetItem to each element and + // remapping variables in earlier bindings appropriately. Thus, a binding whose value depends on + // some elements of a tuple from another group's function must be emitted after a call to the + // tuple-producing function is emitted and remapping is done. + // To guarantee this, we process bindings in the order of the topological sort of the group + // dependency relations. + for (const auto& binding : TopoSortByGroupDep(block->bindings)) { + // Case 1. If the binding is the only binding in its group, recurse into it and emit the + // transformed binding as usual. + Group* group = GetGroupFromBinding(binding); + if (group->num_nodes == 1 && group->attrs.empty()) { + VisitBinding(binding); + continue; + } + + const auto& it_creator = group2func_.find(group); + ICHECK(it_creator != group2func_.end()); + const FunctionCreator& func_info = it_creator->second; + + // If this binding belongs to a group whose output is a tuple, the original bound variable + // needs to be remapped to the output of TupleGetItem after the corresponding tuple is + // emitted. 
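+ // For example (illustrative): if a group's function returns the tuple (lv0, lv1), a later + // use of lv1 is remapped to TupleGetItem(call, 1) once the call to the group is emitted.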
+ if (IsTupleOutput(func_info.function_) && tuple_get_indices_.count(binding->var.get())) { + pending_tuple_get[group].push_back(binding->var); + } + + // Case 2. If the binding is not the last binding of the group, we skip it. + if (!func_info.bindings_.back().same_as(binding)) { + continue; + } + + // Case 3. The binding is the last binding of the group. + const auto* var_binding = binding.as(); + ICHECK(var_binding != nullptr) << "The last binding of a group whose size is larger than 1 " + "is supposed to be a variable binding"; + + // Step a. Add the grouped function to the IRModule + GlobalVar gv = builder_->AddFunction(func_info.function_, func_info.name_hint_); + + // Step b. Create the call to the deduplicated function, and then emit the call. + // - If this binding is an output binding, emit an output variable. + // - Otherwise, emit a dataflow variable. + Var new_var; + Call call_to_emit = Call(gv, UpdateArgs(func_info.arguments_)); + + if (var_binding->var->IsInstance()) { + new_var = builder_->Emit(call_to_emit); + } else { + new_var = builder_->EmitOutput(call_to_emit); + } + + // Step c. Update the mapping used for the remapping of the binding variables. + if (IsTupleOutput(func_info.function_)) { + // If the output is a tuple, attach TupleGetItem to all tuple elements, and + // remap variables appropriately. + // The variables that need to be remapped and the corresponding tuple indices are + // available in pending_tuple_get and tuple_get_indices_ respectively. + for (const auto& var : pending_tuple_get[group]) { + auto tuple_get = TupleGetItem(new_var, tuple_get_indices_[var.get()]); + var_remap_[var->vid] = builder_->Emit(tuple_get); + } + } else { + var_remap_[var_binding->var->vid] = new_var; + } + } + // Step 5. Finish the binding block generation. + return builder_->EndBlock(); + } + + /*! + * \brief Collect the bindings for each grouped function and update the information of the grouped + * function + * \param bindings The bindings to be collected + * \note The function update is done by `AppendBinding(...)` + */ + void CollectFuncBindings(const Array& bindings) { + for (const Binding& binding : bindings) { + // If the binding is the only binding in its group, there is no need to create a new function. + Group* group = GetGroupFromBinding(binding); + if (group->num_nodes == 1 && group->attrs.empty()) { + continue; + } + // Add the binding to the grouped function it's in, and update the function information + // accordingly. + if (!group2func_.count(group)) { + group2func_.emplace(group, lift_constants_); + } + group2func_.find(group)->second.AppendBinding(binding); + } + } + + void CollectFuncBoundary(const Array& bindings) { + for (const Binding& binding : bindings) { + // Step 1. Get current binding's group + Group* cur_group = GetGroupFromBinding(binding); + + // Step 2. Collect all used vars in the binding value and update the boundary. + // - If the var's group is the same as the binding's, the var is defined in the same group + // - If the var's group is different from the binding's, the var must be the output from + // another group. Mark it as the group output. + auto update_boundary = [this, binding, &cur_group](const Expr& e) { + if (e->IsInstance()) { + const Var& used_var = Downcast(e); + Group* producer_group = GetGroupFromVar(used_var); + // Only check those groups defined before. + // Skip the vars from input or groups with a single binding. 
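+ // (Illustrative: if group A consumes a var produced by group B while B already + // consumes one produced by A, fusing both would form a cycle; the check below + // guards against that.)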
+ if (producer_group != cur_group) { + ICHECK(!group_deps_[producer_group].count(cur_group)) + << "A cyclic dependency is detected between the groups that " << binding->var->name_hint() + << " and " << used_var->name_hint() << " are in."; + group_deps_[cur_group].insert(producer_group); + } + + if (auto producer = group2func_.find(producer_group); + producer_group != cur_group && producer != group2func_.end()) { + auto output_index = producer->second.AppendOutput(used_var); + tuple_get_indices_[used_var.get()] = output_index; + } + } + }; + + if (const auto* var_binding = binding.as()) { + PostOrderVisit(var_binding->value, update_boundary); + } else { + const auto* match_cast = binding.as(); + ICHECK_NOTNULL(match_cast); + PostOrderVisit(match_cast->value, update_boundary); + } + } + } + + /*! + * \brief Get the group which the input binding is in + * \param binding The binding to be queried + * \return The pointer to the group which the input binding is in + */ + Group* GetGroupFromBinding(const Binding& binding) { + Var var = binding->var; + return GetGroupFromVar(var); + } + + /*! + * \brief Get the group which the input var is in + * \param var The var to be queried + * \return The pointer to the group which the input var is in + */ + Group* GetGroupFromVar(const Var& var) { + const auto& it_group = obj2group_.find(var.get()); + ICHECK(it_group != obj2group_.end()); + Group* group = it_group->second; + return group->FindRoot(); + } + + /*! + * \brief Update the pre-stored arguments according to the variable remapping of the fusor, by + * recursing into each argument + * \param args The arguments to be updated + * \return The updated arguments + */ + Array UpdateArgs(const Array& args) { + Array new_args; + new_args.reserve(args.size()); + for (const Expr& arg : args) { + new_args.push_back(VisitExpr(arg)); + } + return new_args; + } + + private: + // Topologically sort bindings according to the group dependency relations. + Array TopoSortByGroupDep(const Array& bindings) { + std::unordered_map> bindings_per_group; + // The order to visit groups should respect the original order of bindings as much as possible. + std::vector group_order; + for (const auto& binding : bindings) { + auto g = GetGroupFromBinding(binding); + group_order.push_back(g); // Duplication does not matter since each group is visited once. + bindings_per_group[g].push_back(binding); + } + + std::unordered_set visited; + + std::function)> dfs_visit; + dfs_visit = [this, &visited, &dfs_visit](Group* g, auto leaf_fun) { + if (!visited.count(g)) { + visited.insert(g); + for (auto dep : group_deps_[g]) { + dfs_visit(dep, leaf_fun); + } + leaf_fun(g); + } + }; + + Array sorted; + + for (auto g : group_order) { + dfs_visit(g, [&sorted, &bindings_per_group](Group* leaf) { + for (const auto& binding : bindings_per_group[leaf]) { + sorted.push_back(binding); + } + }); + } + + return sorted; + } + + /*! \brief The IRModule. */ + IRModule mod_; + /*! \brief Internal arena. */ + support::Arena arena_; + /*! \brief The group assignment map. */ + GroupMap obj2group_; + /*! \brief Internal function information map. */ + std::unordered_map group2func_; + /*! \brief Record the index for TupleGetItem if the variable needs to be remapped to an output + * tuple element after fusion. */ + std::unordered_map tuple_get_indices_; + /*! \brief A map from a group to its dependent groups, used to detect cyclic dependencies. */ + std::unordered_map> group_deps_; + /*! \brief Whether or not to lift bound constants to parameters of the grouped function. 
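+ * (If enabled, a Constant operand used by a group becomes an extra parameter of the + * grouped function and is passed in at the call site, instead of being embedded in + * the function body.)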
*/ + bool lift_constants_{true}; +}; + +IRModule FuseOps(IRModule mod, int opt_level, size_t max_fuse_depth) { + support::Arena arena; + + // Step 1. Create the indexed-forward graph according to the input IRModule. + IndexedForwardGraph graph = GraphCreator::Create(mod, &arena); + + // Step 2. Partition the graph by applying the fusion algorithm. + std::vector groups = + GraphPartitioner(&arena, opt_level, max_fuse_depth).Partition(graph); + + // Step 3. Transform the IRModule by fusing the operators in accordance with the graph partition + // results. + return OperatorFusor(mod, graph, groups, /*lift_constants*/ true).Transform(); +} + +namespace transform { + +Pass FuseOps(int fuse_opt_level) { + runtime::TypedPackedFunc pass_func = // + [=](IRModule m, PassContext pc) { + int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level; + auto max_fuse_depth = pc->GetConfig("relax.FuseOps.max_depth", Integer(kMaxFusedOps)); + return relax::FuseOps(m, opt_level, max_fuse_depth.value().IntValue()); + }; + return CreateModulePass(/*pass_function=*/pass_func, // + /*opt_level=*/0, // + /*pass_name=*/"FuseOps", // + /*required=*/{}); +} + +TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc new file mode 100644 index 0000000000000..fa5c296d278ef --- /dev/null +++ b/src/relax/transform/fuse_tir.cc @@ -0,0 +1,728 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include + +#include "../../relay/analysis/graph_partitioner.h" +#include "../../support/arena.h" +#include "../../tir/ir/functor_common.h" + +namespace tvm { +namespace tir { + +// TODO(Siyuan): move it to somewhere under tir folder +/*! + * \brief Substitute a given source buffer with a given target buffer in statements or expressions. 
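+ * For example (illustrative), with the map {A -> B}, the statement `A[vi, vj] = A[vi, vj] + 1.0` + * is rewritten to `B[vi, vj] = B[vi, vj] + 1.0`, and uses of A's data var become B's.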
+ */ +class FuseTIRBufferSubstitor : private StmtExprMutator { + public: + static Stmt Substitute(const Map& buffer_map, Stmt stmt) { + return FuseTIRBufferSubstitor(buffer_map)(std::move(stmt)); + } + + private: + explicit FuseTIRBufferSubstitor(const Map& buffer_map) { + for (const auto& kv : buffer_map) { + const Buffer& src = kv.first; + const Buffer& tgt = kv.second; + buffer_var_map_[src->data.get()] = tgt; + } + } + + PrimExpr VisitExpr_(const VarNode* _op) final { + auto it = buffer_var_map_.find(_op); + if (it != buffer_var_map_.end()) { + return it->second->data; + } else { + return GetRef(_op); + } + } + + PrimExpr VisitExpr_(const BufferLoadNode* _op) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_op)); + auto it = buffer_var_map_.find(load->buffer->data.get()); + if (it != buffer_var_map_.end()) { + auto n = make_object(*load.get()); + n->buffer = it->second; + return BufferLoad(n); + } else { + return std::move(load); + } + } + + Stmt VisitStmt_(const BufferStoreNode* _op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_op)); + auto it = buffer_var_map_.find(store->buffer->data.get()); + if (it != buffer_var_map_.end()) { + auto n = CopyOnWrite(store.get()); + n->buffer = it->second; + return BufferStore(n); + } else { + return std::move(store); + } + } + + PrimExpr VisitExpr_(const LoadNode* _op) final { + Load load = Downcast(StmtExprMutator::VisitExpr_(_op)); + auto it = buffer_var_map_.find(load->buffer_var.get()); + if (it != buffer_var_map_.end()) { + auto n = make_object(*load.get()); + n->buffer_var = it->second->data; + return Load(n); + } else { + return std::move(load); + } + } + + Stmt VisitStmt_(const StoreNode* _op) final { + Store store = Downcast(StmtExprMutator::VisitStmt_(_op)); + auto it = buffer_var_map_.find(store->buffer_var.get()); + if (it != buffer_var_map_.end()) { + auto n = CopyOnWrite(store.get()); + n->buffer_var = it->second->data; + return Store(n); + } else { + return std::move(store); + } + } + + Stmt VisitStmt_(const BlockNode* _op) final { + Block block = Downcast(StmtMutator::VisitStmt_(_op)); + + // Define the mutation functions. + auto f_mutate_match_buffers = [this](const MatchBufferRegion& match_buffer) { + const Buffer& src_buffer = match_buffer->source->buffer; + auto it = buffer_var_map_.find(src_buffer->data.get()); + if (it != buffer_var_map_.end()) { + return MatchBufferRegion(match_buffer->buffer, + BufferRegion(it->second, match_buffer->source->region)); + } else { + return match_buffer; + } + }; + + auto f_mutate_read_write_region = [this](const BufferRegion& buffer_region) { + auto it = buffer_var_map_.find(buffer_region->buffer->data.get()); + return it == buffer_var_map_.end() ? buffer_region + : BufferRegion(it->second, buffer_region->region); + }; + + // Step 1. Mutate `match_buffers`. + Array match_buffers = + MutateArray(block->match_buffers, f_mutate_match_buffers); + // Step 2. Mutate the read/write region. + Array reads = MutateArray(block->reads, f_mutate_read_write_region); + Array writes = MutateArray(block->writes, f_mutate_read_write_region); + + reads = UnionAccessRegion(reads); + writes = UnionAccessRegion(writes); + + if (reads.same_as(block->reads) && // + writes.same_as(block->writes) && // + match_buffers.same_as(block->match_buffers)) { + return std::move(block); + } else { + auto n = CopyOnWrite(block.get()); + n->reads = std::move(reads); + n->writes = std::move(writes); + n->match_buffers = std::move(match_buffers); + return Block(n); + } + } + + private: + /*! 
\brief Mapping from src buffer.data to tgt buffer. */ + std::unordered_map buffer_var_map_; + /*! \brief The structural equality checker */ + StructuralEqual structural_equal_; + + Array UnionAccessRegion(const Array& regions) const { + // For now we only allow accesses to the same elements of a buffer. + // e.g. `[A[vi, vj], A[vi, vj]]` is a legal pattern but needs to be unioned into `A[vi, vj]` + // However, `A[vi, vj], A[vi, vj + 1]` is not allowed for now. + // Note: the order of the returned regions should follow the first occurrence of each region + Array ret; + std::unordered_map buffer_region_set; + + for (const BufferRegion& region : regions) { + auto it = buffer_region_set.find(region->buffer.get()); + if (it == buffer_region_set.end()) { + ret.push_back(region); + buffer_region_set[region->buffer.get()] = region->region; + } else { + ICHECK(structural_equal_(region->region, it->second)); + } + } + + if (ret.size() == regions.size()) { + return regions; + } else { + return ret; + } + } +}; + +/*! \brief A mutator which detects block name duplication and deduplicates the names. */ +class BlockNameDeduplicator : public tir::StmtMutator { + private: + Stmt VisitStmt_(const BlockNode* op) final { + Block block = Downcast(tir::StmtMutator::VisitStmt_(op)); + + String name = GetUniqueName(block->name_hint); + + if (name == block->name_hint) { + return std::move(block); + } else { + ObjectPtr n = CopyOnWrite(block.get()); + n->name_hint = std::move(name); + return Stmt(n); + } + } + + String GetUniqueName(const String& prefix) { + String unique_prefix = prefix; + auto it = name_count_.find(prefix); + while (name_count_.count(unique_prefix)) { + unique_prefix = prefix + "_" + std::to_string(++it->second); + } + name_count_[unique_prefix] = 0; + return unique_prefix; + } + + // TODO(relax-team): It should detect the number suffix and do the renaming properly + // e.g. GetUniqueName("name1") should return "name2" instead of "name10". + /*! \brief The count map to make block name unique. */ + std::unordered_map name_count_; +}; + +} // namespace tir + +namespace relax { + +class FusedTIRConstructor : public ExprVisitor { + public: + /*! + * \brief Construct a fused TIR PrimFunc from a relax sub-function + * \param mod The IRModule + * \param gv The global var of the relax sub-function to be fused into one PrimFunc + * \return The fused TIR PrimFunc + */ + static tir::PrimFunc GetFusedTIR(const IRModule& mod, const GlobalVar& gv) { + FusedTIRConstructor visitor(mod, gv->name_hint); + BaseFunc f = mod->Lookup(gv); + CHECK(f->IsInstance()) + << "Expected a relax function, but got: " << f->GetTypeKey(); + CHECK(f->HasNonzeroAttr(relax::attr::kPrimitive)) + << "Expected a function with attr `kPrimitive`"; + visitor(Downcast(f)); + return visitor.fused_tir_; + } + + private: + explicit FusedTIRConstructor(const IRModule& mod, const String& func_name) + : mod_(mod), func_name_(func_name) {} + + void VisitExpr_(const FunctionNode* func) final { + // Step 1. Create buffers for function params + for (const Var& relax_param : func->params) { + auto ret = CreateParamsAndBuffers(GetStructInfo(relax_param), // + relax_param->name_hint()); + const Array& params = ret.first; + const Array& buffers = ret.second; + ICHECK_EQ(params.size(), buffers.size()); + for (size_t i = 0; i < params.size(); ++i) { + func_info_.buffer_map.Set(params[i], buffers[i]); + func_info_.params.push_back(params[i]); + } + func_info_.expr2buffers.Set(relax_param, buffers); + } + + // Step 2. 
Visit Function body and create intermediate buffers + ExprVisitor::VisitExpr_(func); + + // Step 3. Create and remap buffers for function output + ICHECK(func->body->IsInstance()) + << "Function body is expected to be a SeqExpr, but got: " << func->body->GetTypeKey(); + Expr body = Downcast(func->body)->body; + auto it = func_info_.expr2buffers.find(body); + ICHECK(it != func_info_.expr2buffers.end()) + << "Failed to detect output buffers for function body"; + const Array& buffers = (*it).second; + for (size_t i = 0; i < buffers.size(); ++i) { + tir::Var param = tir::Var("p_output" + std::to_string(i), PrimType(DataType::Handle())); + func_info_.buffer_map.Set(param, buffers[i]); + func_info_.params.push_back(param); + func_info_.output_buffers.insert(buffers[i].get()); + } + + // Step 4. Create PrimFunc + fused_tir_ = ConstructFunc(); + } + + void VisitBinding_(const VarBindingNode* binding) final { + // Update expr2buffers by visiting values. + this->VisitExpr(binding->value); + auto it = func_info_.expr2buffers.find(binding->value); + if (it != func_info_.expr2buffers.end()) { + // Assign the binding var to the buffers of the value + func_info_.expr2buffers.Set(binding->var, (*it).second); + } else { + LOG(FATAL) << "Unsupported binding value: " << binding->value; + } + } + + void VisitBinding_(const MatchCastNode* match_cast) final { + LOG(FATAL) << "MatchCast is unsupported in primitive functions"; + } + + void VisitExpr_(const CallNode* call) final { + ExprVisitor::VisitExpr_(call); + static const Op& call_tir_op_ = Op::Get("relax.call_tir"); + ICHECK(call->op == call_tir_op_) + << "Only call_tir is supported in primitive functions, but got: " << GetRef(call); + + // Step 1. Get Global var and PrimFunc + GlobalVar gv = Downcast(call->args[0]); + Optional prim_func_ = GetPrimFunc(gv); + ICHECK(prim_func_.defined()) << "Cannot find the prim_func of the call_tir in the module: " + << gv; + // Step 2. Renew all vars/buffer definitions and blocks to avoid duplication + tir::PrimFunc prim_func = tir::RenewDefs(prim_func_.value()); + + // Step 3. Check that all functions are schedulable funcs, i.e. the body of the func is a root block + // TODO(Siyuan): support un-schedulable functions. + ICHECK(prim_func->body->IsInstance()) + << "Only schedulable functions (whose body is the root block) can be fused"; + const tir::BlockRealize& root_realize = Downcast(prim_func->body); + const tir::Block& root_block = root_realize->block; + + // Step 4. Add all the original alloc_buffers and body to the fused function. + func_info_.alloc_buffers.insert(func_info_.alloc_buffers.end(), + root_block->alloc_buffers.begin(), + root_block->alloc_buffers.end()); + func_info_.bodies.push_back(root_block->body); + + // Step 5. 
Map input arguments to buffers + MapInputBuffer(prim_func, call->args[1]); + size_t num_output_buffers = GetCallTIROutputSize(call); + AllocateIntermediateBuffer(GetRef(call), prim_func, num_output_buffers); + // Update fused func name + func_info_.global_name += "_" + gv->name_hint; + } + + void VisitExpr_(const TupleGetItemNode* tuple_get_item) final { + ExprVisitor::VisitExpr_(tuple_get_item); + auto it = func_info_.expr2buffers.find(tuple_get_item->tuple); + if (it != func_info_.expr2buffers.end()) { + int begin_buf_idx = 0; + int end_buf_idx = 0; + const TupleType& tuple_type = Downcast(tuple_get_item->tuple->checked_type()); + for (int i = 0; i < tuple_get_item->index; ++i) { + begin_buf_idx += GetTotalTensorSize(tuple_type->fields[i]); + } + end_buf_idx = begin_buf_idx + GetTotalTensorSize(tuple_type->fields[tuple_get_item->index]); + func_info_.expr2buffers.Set( + GetRef(tuple_get_item), + {(*it).second.begin() + begin_buf_idx, (*it).second.begin() + end_buf_idx}); + } + } + + void VisitExpr_(const TupleNode* tuple) final { + ExprVisitor::VisitExpr_(tuple); + Array buffers; + for (const Expr& expr : tuple->fields) { + auto it = func_info_.expr2buffers.find(expr); + if (it != func_info_.expr2buffers.end()) { + buffers.insert(buffers.end(), (*it).second.begin(), (*it).second.end()); + } + } + if (!buffers.empty()) { + func_info_.expr2buffers.Set(GetRef(tuple), buffers); + } + } + + void VisitExpr_(const ConstantNode* op) final { + LOG(FATAL) << "Relax.Constant is not supported in primitive functions."; + } + + /********** Helper Functions **********/ + + /*! + * \brief Pattern match op to a TIR function and look it up. + * \return The TIR function, or NullOpt if the pattern match fails. + */ + Optional GetPrimFunc(const GlobalVar& global_var) { + // NOTE: the as() check works for nullptr (returns null) + Optional base_func = mod_->functions.Get(global_var); + if (auto* pfunc = base_func.as()) { + return GetRef(pfunc); + } else { + return NullOpt; + } + } + + /*! + * \brief Get the number of outputs for a call_tir node. + * \return The number of outputs. + */ + static size_t GetCallTIROutputSize(const CallNode* call) { + static const Op& call_tir_op_ = Op::Get("relax.call_tir"); + ICHECK(call->op.same_as(call_tir_op_)); + ICHECK_EQ(call->sinfo_args.size(), 1); + if (const auto* tuple_sinfo = call->sinfo_args[0].as()) { + return tuple_sinfo->fields.size(); + } else { + return 1; + } + } + + /*! \brief Map old TIR func param buffers to new buffers, and then update `buffer_subst_map` */ + void MapArgsToBuffer(const Array args, const Array& buffers) { + size_t buffer_idx = 0; + for (const Expr& arg : args) { + if (const auto* v = arg.as()) { + auto it = func_info_.expr2buffers.find(GetRef(v)); + // Substitute the buffer with the already allocated one if it is an intermediate var + if (it != func_info_.expr2buffers.end()) { + for (const tir::Buffer& target_buffer : (*it).second) { + ICHECK_LT(buffer_idx, buffers.size()); + const tir::Buffer& buffer = buffers[buffer_idx]; + // TODO(relax-team): Add support for symbolic shape fusion + for (const PrimExpr& shape_expr : buffer->shape) { + ICHECK(shape_expr.as()) << "Only support constant shape fusion for now"; + } + func_info_.buffer_subst_map.Set(buffer, target_buffer); + buffer_idx++; + } + } + } + } + // Make sure every buffer is mapped. + ICHECK_EQ(buffer_idx, buffers.size()); + } + + /*! + * \brief Update buffer mapping `func_info_.buffer_subst_map` for input args + * \param func The old TIR PrimFunc + * \param args The relax arguments to the call_tir, 
which may be a Tuple of arguments or a single argument. + */ + void MapInputBuffer(const tir::PrimFunc& func, const relax::Expr& args) { + Array arg_list; + Array buffer_list; + if (const auto* arg_tuple = args.as()) { + arg_list = arg_tuple->fields; + } else { + arg_list = {args}; + } + + ICHECK_GE(func->params.size(), arg_list.size()); + for (size_t i = 0; i < arg_list.size(); ++i) { + const tir::Var& param = func->params[i]; + const tir::Buffer& buffer = func->buffer_map.at(param); + buffer_list.push_back(buffer); + } + + MapArgsToBuffer(arg_list, buffer_list); + } + + /*! + * \brief Allocate buffer(s) and update `func_info.expr2buffers` if the PrimFunc output(s) are + * intermediate results. + * \param expr The relax Expr, which can be binding vars or binding values. + * \param func The old TIR PrimFunc + * \param output_size The number of output params. All output params are at the end of the param list. + */ + void AllocateIntermediateBuffer(const Expr& expr, const tir::PrimFunc& func, size_t output_size) { + size_t n = func->params.size(); + ICHECK_GE(n, output_size); + // Allocate intermediate buffer + Array alloc_buffers; + for (size_t i = 0; i < output_size; ++i) { + const tir::Var& param = func->params[n - output_size + i]; + const tir::Buffer& buffer = func->buffer_map.at(param); + func_info_.alloc_buffers.push_back(buffer); + alloc_buffers.push_back(buffer); + } + // Update expr2buffers + func_info_.expr2buffers.Set(expr, alloc_buffers); + } + + /*! + * \brief Create TIR func params and buffers with the specified relax type and shape + * \param struct_info The struct info + * \param name_hint The name hint for params and buffers + * \param index The index used for a unique name_hint if the type is Tuple. + * -1 means there is no need to add a postfix since the relax param is not a Tuple. + * \return The created TIR func params and buffers + */ + static std::pair, Array> CreateParamsAndBuffers( + StructInfo struct_info, const String& name_hint, int index = -1) { + Array params; + Array buffers; + if (const auto* tensor = struct_info.as()) { + // Case 1. The relax param is a DynTensor; we directly create a tir var and buffer + const auto* shape_expr = tensor->shape.as(); + ICHECK(shape_expr) << "FuseTIR expects all parameters to be Tensors with symbolic shape."; + + String name = index == -1 ? name_hint : name_hint + "_" + std::to_string(index); + DataType dtype = tensor->dtype; + tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, name); + // Differentiate buffer name and param name by adding prefix `p_` to the param. + // Every symbol should be unique in TVMScript, and buffers are used more often than params, + // so we make sure buffer names have better readability. + tir::Var param = tir::Var("p_" + name, PrimType(DataType::Handle())); + params.push_back(std::move(param)); + buffers.push_back(std::move(buffer)); + } else if (const auto* tuple = struct_info.as()) { + // Case 2. The relax param is a Tuple; we recursively visit each field until it is a DynTensor + // Enable the postfix + if (index == -1) index = 0; + for (size_t i = 0; i < tuple->fields.size(); ++i) { + auto ret = CreateParamsAndBuffers(tuple->fields[i], name_hint, index); + const Array& ret_params = ret.first; + const Array& ret_buffers = ret.second; + ICHECK_EQ(ret_params.size(), ret_buffers.size()); + // Adding tuple field results to the end of params and buffers. 
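+ // (Illustrative: a relax param `x : Tuple(Tensor, Tuple(Tensor, Tensor))` unfolds into + // buffers `x_0`, `x_1`, `x_2` with TIR params `p_x_0`, `p_x_1`, `p_x_2`.)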
+        params.insert(params.end(), ret_params.begin(), ret_params.end());
+        buffers.insert(buffers.end(), ret_buffers.begin(), ret_buffers.end());
+        index += ret_params.size();
+      }
+    } else {
+      ICHECK(false) << "The param struct info is expected to be TensorStructInfo or "
+                    << "TupleStructInfo, but got: " << struct_info;
+    }
+    return std::make_pair(params, buffers);
+  }
+
+  /*!
+   * \brief Construct the fused TIR func with the collected FuseFuncInfo
+   * \return The fused TIR func
+   */
+  tir::PrimFunc ConstructFunc() {
+    Map<String, ObjectRef> attr_map;
+    attr_map.Set("tir.noalias", tir::const_true());
+    ICHECK(func_info_.global_name != "fused");
+    // Remove output buffers from func_info_.alloc_buffers
+    Array<tir::Buffer> alloc_buffers;
+    for (const tir::Buffer& buf : func_info_.alloc_buffers) {
+      if (func_info_.output_buffers.count(buf.get()) == 0) {
+        alloc_buffers.push_back(buf);
+      }
+    }
+    tir::Stmt body = tir::BlockNameDeduplicator()(tir::SeqStmt::Flatten(func_info_.bodies));
+    body = tir::FuseTIRBufferSubstitor::Substitute(func_info_.buffer_subst_map, body);
+    body = tir::Block({}, {}, {}, "root", std::move(body), NullOpt, alloc_buffers);
+    body = tir::BlockRealize({}, Bool(true), Downcast<tir::Block>(body));
+    tir::PrimFunc func(func_info_.params, body, VoidType(), func_info_.buffer_map,
+                       DictAttrs(attr_map));
+    return func;
+  }
+
+  /*! \brief Get the number of DynTensors from recursive Tuples. */
+  static size_t GetTotalTensorSize(const Type& type) {
+    if (type.as<DynTensorTypeNode>()) {
+      return 1;
+    } else if (const auto* tuple_type = type.as<TupleTypeNode>()) {
+      size_t num = 0;
+      for (const Type& type : tuple_type->fields) {
+        num += GetTotalTensorSize(type);
+      }
+      return num;
+    } else {
+      LOG(FATAL) << "DynTensorType and TupleType are expected, but got: " << type;
+      return 0;
+    }
+  }
+
+  /********** Function Info **********/
+
+  /*! \brief Auxiliary information for FuseTIR */
+  struct FuseFuncInfo {
+    /*! \brief The arguments for calling prim_func */
+    Array<Expr> arguments;
+    /*!
+     * \brief The map from each dataflow var (intermediate var) to the corresponding buffers
+     * allocated in the fused func
+     */
+    Map<Expr, Array<tir::Buffer>> expr2buffers;
+    /*! \brief The buffers to allocate in the fused func */
+    Array<tir::Buffer> alloc_buffers;
+    /*! \brief The bodies of the original funcs, which together form the body of the fused func. */
+    Array<tir::Stmt> bodies;
+    /*! \brief The params of the fused function */
+    Array<tir::Var> params;
+    /*!
+     * \brief The map from a buffer in the original functions to the corresponding buffer in the
+     * fused function
+     */
+    Map<tir::Buffer, tir::Buffer> buffer_subst_map;
+    /*! \brief The `buffer_map` of the fused function */
+    Map<tir::Var, tir::Buffer> buffer_map;
+    /*! \brief The output buffers in the function buffer_map */
+    std::unordered_set<const tir::BufferNode*> output_buffers;
+    /*! \brief The name of the fused function */
+    std::string global_name = "fused";
+  };
+
+  /*! \brief The IRModule */
+  const IRModule& mod_;
+  /*! \brief The name hint for the input func. */
+  String func_name_;
+  /*! \brief The helper info to fuse TIR prim_funcs */
+  FuseFuncInfo func_info_;
+  /*! \brief The TIR function after fusion */
+  tir::PrimFunc fused_tir_;
+};
+
+/*!
+ * \brief The helper class to fuse TIR functions and build a new module which calls the fused TIR.
+ */
+class TIRFuseMutator : public ExprMutator {
+ public:
+  static IRModule Transform(const IRModule& mod) {
+    // Since TIRFuseMutator will delete a bunch of PrimFuncs, we create an empty block builder.
+    TIRFuseMutator mutator(mod);
+    // Step 1. Fuse all primitive relax functions and store the results in `fused_tir_funcs_`
+    for (const auto& kv : mod->functions) {
+      const GlobalVar& gv = kv.first;
+      const BaseFunc& func = kv.second;
+      // Only fuse primitive relax functions
+      if (func->IsInstance<relax::FunctionNode>() && func->HasNonzeroAttr(attr::kPrimitive)) {
+        tir::PrimFunc fused_tir = FusedTIRConstructor::GetFusedTIR(mod, gv);
+        mutator.fused_tir_funcs_.Set(gv, fused_tir);
+      }
+    }
+
+    // Step 2. Update all non-primitive relax functions and add them, together with the
+    // functions they depend on, into the new IRModule
+    for (const auto& kv : mod->functions) {
+      const GlobalVar& gv = kv.first;
+      const BaseFunc& func = kv.second;
+      if (func->IsInstance<relax::FunctionNode>() && !func->HasNonzeroAttr(attr::kPrimitive)) {
+        relax::Function update_func = Downcast<relax::Function>(mutator.VisitExpr(func));
+        mutator.builder_->AddFunction(update_func, gv->name_hint);
+      }
+    }
+    return mutator.builder_->GetContextIRModule();
+  }
+
+ private:
+  explicit TIRFuseMutator(const IRModule& mod) : mod_(mod) {}
+
+  using ExprMutator::VisitExpr_;
+
+  // Get the shape from a call_tir
+  static Expr GetCallTIRShape(StructInfo sinfo) {
+    if (auto* tuple = sinfo.as<TupleStructInfoNode>()) {
+      Array<Expr> fields = tuple->fields.Map([&](StructInfo x) { return GetCallTIRShape(x); });
+      return Tuple(fields);
+    } else {
+      auto* tensor = sinfo.as<TensorStructInfoNode>();
+      ICHECK(tensor) << "FuseTIR can only take tensor or tuple types";
+      auto* shape_expr = tensor->shape.as<ShapeExprNode>();
+      ICHECK(shape_expr) << "FuseTIR requires all intermediate values to have a shape";
+      return GetRef<ShapeExpr>(shape_expr);
+    }
+  }
+
+  Expr VisitExpr_(const CallNode* op) final {
+    static const Op& call_tir_op_ = Op::Get("relax.call_tir");
+    Call call = Downcast<Call>(builder_->Normalize(ExprMutator::VisitExpr_(op)));
+
+    if (call->op->IsInstance<GlobalVarNode>()) {
+      // Case 1. It is a relax cross-function call
+      GlobalVar old_gv = Downcast<GlobalVar>(call->op);
+      auto it = fused_tir_funcs_.find(old_gv);
+      if (it != fused_tir_funcs_.end()) {
+        const tir::PrimFunc& fused_tir = (*it).second;
+        // Case 1.1. It calls a primitive relax function, update the call into a call_tir
+        GlobalVar fused_tir_gv = this->builder_->AddFunction(fused_tir, old_gv->name_hint);
+        // Step a. Flatten all args since call_tir does not support Tuple values.
+        Array<Expr> arg_list;
+        for (const Expr& arg : call->args) {
+          Array<Expr> flattened = FlattenArg(arg);
+          arg_list.insert(arg_list.end(), flattened.begin(), flattened.end());
+        }
+        // Step b. Create the call_tir
+        Array<Expr> call_args = {fused_tir_gv, Tuple(arg_list)};
+        return Call(call_tir_op_, call_args, call->attrs, {GetStructInfo(call)});
+      } else {
+        // Case 1.2. The callee function is not primitive, nothing to do.
+        return call;
+      }
+    } else if (call->op == call_tir_op_) {
+      // Case 2. It is a call_tir, re-emit the PrimFunc.
+      GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+      tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
+      GlobalVar new_gv = this->builder_->AddFunction(func, gv->name_hint);
+      return Call(call->op, {new_gv, call->args[1]}, call->attrs, call->sinfo_args, call->span);
+    } else {
+      // Case 3. A CallNode of any other kind. Leave it as it is.
+      return call;
+    }
+  }
+
+  /********** Helper Functions **********/
+
+  /*! \brief Flatten the call args if it's a Tuple by emitting `TupleGetItem`. */
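+  // e.g. an argument of struct info ((A, B), C) flattens to three tensor args via nested
+  // TupleGetItem; the recursion stops once no tuple remains.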
+  Array<Expr> FlattenArg(const Expr& arg) {
+    if (const auto* tuple_sinfo = GetStructInfoAs<TupleStructInfoNode>(arg)) {
+      Array<Expr> arg_list;
+      for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) {
+        Expr new_arg = builder_->Emit(TupleGetItem(arg, i));
+        Array<Expr> flattened = FlattenArg(new_arg);
+        arg_list.insert(arg_list.end(), flattened.begin(), flattened.end());
+      }
+      return arg_list;
+    } else {
+      return {arg};
+    }
+  }
+
+ private:
+  /*! \brief The IRModule */
+  const IRModule& mod_;
+  /*! \brief The map from the global var of a primitive relax function to the generated prim
+   * func. */
+  Map<GlobalVar, tir::PrimFunc> fused_tir_funcs_;
+};
+
+IRModule FuseTIR(IRModule mod) {
+  mod = TIRFuseMutator::Transform(mod);
+  return mod;
+}
+
+namespace transform {
+
+Pass FuseTIR() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =  //
+      [=](IRModule m, PassContext pc) { return relax::FuseTIR(m); };
+  return CreateModulePass(/*pass_function=*/pass_func,  //
+                          /*opt_level=*/0,              //
+                          /*pass_name=*/"FuseTIR",      //
+                          /*required=*/{});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.FuseTIR").set_body_typed(FuseTIR);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/tests/python/relax/test_transform_annotate_tir_op_pattern.py b/tests/python/relax/test_transform_annotate_tir_op_pattern.py
new file mode 100644
index 0000000000000..73c65378693ad
--- /dev/null
+++ b/tests/python/relax/test_transform_annotate_tir_op_pattern.py
@@ -0,0 +1,360 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
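+
+# AnnotateTIROpPattern stores the inferred kind on each PrimFunc as the integer
+# attribute "op_pattern"; the enum below mirrors relay::OpPatternKind so the tests
+# can assert against readable names instead of raw integers.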
+import enum + +import tvm +import tvm.script +import tvm.testing +from tvm import relax +from tvm.script import tir as T + + +class OpPatternKind(enum.IntEnum): + kElemWise = 0 + kBroadcast = 1 + kInjective = 2 + kCommReduce = 3 + kOutEWiseFusable = 4 + kTuple = 7 + kOpaque = 8 + + +def test_annotate_opkind_outewisefusable(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + m = T.var("int32") + n = T.var("int32") + k = T.var("int32") + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (n, k)) + C = T.match_buffer(z, (m, k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["tir_matmul"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable + + +def test_annotate_opkind_outewisefusable_int_var_signature(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle, m: T.int64, n: T.int64, k: T.int64): + T.func_attr({"global_symbol": "tir_matmul"}) + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (n, k)) + C = T.match_buffer(z, (m, k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["tir_matmul"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable + + +def test_annotate_opkind_reduce(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def sum(x: T.handle, y: T.handle) -> None: + T.func_attr({"global_symbol": "elemwise"}) + A = T.match_buffer(x, (16, 16)) + B = T.match_buffer(y, (16,)) + + for i, j in T.grid(16, 16): + with T.block("matmul"): + vi, vj = T.axis.remap("SR", [i, j]) + with T.init(): + B[vi] = 0.0 + B[vi] += A[vi, vj] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["sum"].attrs["op_pattern"] == OpPatternKind.kCommReduce + + +def test_annotate_opkind_ewise(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def elemwise(x: T.handle, y: T.handle) -> None: + T.func_attr({"global_symbol": "elemwise"}) + A = T.match_buffer(x, (16, 16)) + B = T.match_buffer(y, (16, 16)) + + for i, j in T.grid(16, 16): + with T.block("matmul"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + 1.0 + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["elemwise"].attrs["op_pattern"] == OpPatternKind.kElemWise + + +def test_annotate_opkind_broadcast(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def broadcast(x: T.handle, y: T.handle) -> None: + T.func_attr({"global_symbol": "elemwise"}) + A = T.match_buffer(x, (16, 16)) + B = T.match_buffer(y, (16, 16, 16, 16)) + + for i0, j0, i1, j1 in T.grid(16, 16, 16, 16): + with T.block("matmul"): + vi0, vj0, vi1, vj1 = T.axis.remap("SSSS", [i0, j0, i1, j1]) + B[vi0, vj0, vi1, vj1] = A[vj0, vj1] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["broadcast"].attrs["op_pattern"] == OpPatternKind.kBroadcast + + +def test_annotate_opkind_injective(): + @tvm.script.ir_module + class InputModule: + @T.prim_func 
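+        # Not element-wise: each output element reads A[vi // 4, vj // 4, vi % 4, vj % 4],
+        # an index permutation of the input, so the kind is injective.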
+ def injective(x: T.handle, y: T.handle) -> None: + T.func_attr({"global_symbol": "elemwise"}) + A = T.match_buffer(x, (4, 4, 4, 4)) + B = T.match_buffer(y, (16, 16)) + + for i, j in T.grid(16, 16): + with T.block("matmul"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi // 4, vj // 4, vi % 4, vj % 4] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["injective"].attrs["op_pattern"] == OpPatternKind.kInjective + + +def test_annotate_opkind_bias_add(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_bias_add( + A: T.Buffer((1, 1000), "float32"), + B: T.Buffer((1000,), "float32"), + C: T.Buffer((1, 1000), "float32"), + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "tir_bias_add", "tir.noalias": True}) + # body + # with T.block("root") + for i0, i1 in T.grid(1, 1000): + with T.block("T_add"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[ax0, ax1], B[ax1]) + T.writes(C[ax0, ax1]) + C[ax0, ax1] = A[ax0, ax1] + B[ax1] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["tir_bias_add"].attrs["op_pattern"] == OpPatternKind.kElemWise + + +def test_annotate_opkind_add_broadcast_with_unit_shape(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def add_with_unit_dim_len_broadcast( + A: T.Buffer((1, 64, 112, 112), "float32"), + B: T.Buffer((64, 1, 1), "float32"), + C: T.Buffer((1, 64, 112, 112), "float32"), + ) -> None: + T.func_attr({"global_symbol": "add5", "tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(1, 64, 112, 112): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(A[ax0, ax1, ax2, ax3], B[ax1, 0, 0]) + T.writes(C[ax0, ax1, ax2, ax3]) + C[ax0, ax1, ax2, ax3] = A[ax0, ax1, ax2, ax3] + B[ax1, 0, 0] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["add_with_unit_dim_len_broadcast"].attrs["op_pattern"] == OpPatternKind.kElemWise + + +def test_annotate_opkind_add_zero_dim_element_wise(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def add_zero_dim( + A: T.Buffer((128,), "float32"), + B: T.Buffer((), "float32"), + C: T.Buffer((128,), "float32"), + ) -> None: + T.func_attr({"global_symbol": "add8", "tir.noalias": True}) + for i0 in T.serial(128): + with T.block("T_add"): + ax0 = T.axis.spatial(128, i0) + T.reads(A[ax0], B[()]) + T.writes(C[ax0]) + C[ax0] = A[ax0] + B[()] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["add_zero_dim"].attrs["op_pattern"] == OpPatternKind.kElemWise + + +def test_annotate_opkind_pooling(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def max_pool2d( + rxplaceholder_1: T.Buffer((1, 64, 112, 112), "float32"), + tensor_1: T.Buffer((1, 64, 56, 56), "float32"), + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "max_pool2d", "T.noalias": True}) + # body + # with T.block("root") + pad_temp_1 = T.alloc_buffer([1, 64, 114, 114], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 64, 114, 114): + with T.block("pad_temp"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1]) + T.writes(pad_temp_1[ax0, ax1, ax2, ax3]) + pad_temp_1[ax0, ax1, ax2, ax3] = T.if_then_else( + 1 <= ax2 and ax2 < 113 and 1 <= ax3 and ax3 < 113, + rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1], + T.float32(-3.4028234663852886e38), + dtype="float32", + ) + for i0, i1, i2, i3, i4, i5 in 
T.grid(1, 64, 56, 56, 3, 3): + with T.block("tensor"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads( + tensor_1[ax0, ax1, ax2, ax3], + pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1], + ) + T.writes(tensor_1[ax0, ax1, ax2, ax3]) + with T.init(): + tensor_1[ax0, ax1, ax2, ax3] = T.float32(-3.4028234663852886e38) + tensor_1[ax0, ax1, ax2, ax3] = T.max( + tensor_1[ax0, ax1, ax2, ax3], + pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1], + ) + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["max_pool2d"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable + + +def test_annotate_opkind_softmax(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def softmax( + rxplaceholder_1: T.Buffer((16, 16), "float32"), + T_softmax_norm_1: T.Buffer((16, 16), "float32"), + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "softmax", "T.noalias": True}) + # body + # with T.block("root") + T_softmax_maxelem_1 = T.alloc_buffer([16], dtype="float32") + T_softmax_exp_1 = T.alloc_buffer([16, 16], dtype="float32") + T_softmax_expsum_1 = T.alloc_buffer([16], dtype="float32") + for i0_7, i1_3 in T.grid(16, 16): + with T.block("T_softmax_maxelem"): + i0_8, k = T.axis.remap("SR", [i0_7, i1_3]) + T.reads(T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k]) + T.writes(T_softmax_maxelem_1[i0_8]) + with T.init(): + T_softmax_maxelem_1[i0_8] = T.float32(-3.4028234663852886e38) + T_softmax_maxelem_1[i0_8] = T.max( + T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k] + ) + for i0_9, i1_4 in T.grid(16, 16): + with T.block("T_softmax_exp"): + i0_10, i1_5 = T.axis.remap("SS", [i0_9, i1_4]) + T.reads(rxplaceholder_1[i0_10, i1_5], T_softmax_maxelem_1[i0_10]) + T.writes(T_softmax_exp_1[i0_10, i1_5]) + T_softmax_exp_1[i0_10, i1_5] = T.exp( + rxplaceholder_1[i0_10, i1_5] - T_softmax_maxelem_1[i0_10], dtype="float32" + ) + for i0_11, i1_6 in T.grid(16, 16): + with T.block("T_softmax_expsum"): + i0_12, k = T.axis.remap("SR", [i0_11, i1_6]) + T.reads(T_softmax_expsum_1[i0_12], T_softmax_exp_1[i0_12, k]) + T.writes(T_softmax_expsum_1[i0_12]) + with T.init(): + T_softmax_expsum_1[i0_12] = T.float32(0) + T_softmax_expsum_1[i0_12] = ( + T_softmax_expsum_1[i0_12] + T_softmax_exp_1[i0_12, k] + ) + for i0_13, i1_7 in T.grid(16, 16): + with T.block("T_softmax_norm"): + i0_14, i1_8 = T.axis.remap("SS", [i0_13, i1_7]) + T.reads(T_softmax_exp_1[i0_14, i1_8], T_softmax_expsum_1[i0_14]) + T.writes(T_softmax_norm_1[i0_14, i1_8]) + T.block_attr({"axis": 1}) + T_softmax_norm_1[i0_14, i1_8] = ( + T_softmax_exp_1[i0_14, i1_8] / T_softmax_expsum_1[i0_14] + ) + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["softmax"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable + + +def test_multiple_bufer_stores_fallback(): + @tvm.script.ir_module + class CumsumModule: + @T.prim_func + def cumsum(var_rxplaceholder: T.handle, out_buf: T.Buffer(160, "float32")): + rxplaceholder = T.match_buffer( + var_rxplaceholder, [10, 16], dtype="float32", offset_factor=1 + ) + with T.block("cumsum_generic"): + T.reads(rxplaceholder[0:10, 0:16]) + T.writes(out_buf[0:160]) + for fused in T.parallel(1): + out_buf[fused * 160] = rxplaceholder[fused * 160 // 16, fused * 160 % 16] + for v_k in T.serial(159): + out_buf[fused * 160 + (v_k + 1)] = ( + out_buf[fused * 160 + (v_k + 1 - 1)] + + rxplaceholder[ + (fused * 160 + (v_k + 1)) // 16, + (fused * 160 + (v_k + 1)) % 16, + ] + ) + + mod = CumsumModule + new_mod = 
relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["cumsum"].attrs["op_pattern"] == OpPatternKind.kOpaque + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py new file mode 100644 index 0000000000000..1a228bb268fab --- /dev/null +++ b/tests/python/relax/test_transform_fuse_ops.py @@ -0,0 +1,759 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm import relax, topi +from tvm.script import relax as R + + +def _check(mod_actual, mod_expected): + mod_actual = relax.transform.AnnotateTIROpPattern()(mod_actual) + mod_actual = relax.transform.FuseOps()(mod_actual) + mod_expected = relax.transform.AnnotateTIROpPattern()(mod_expected) + tvm.ir.assert_structural_equal(mod_actual, mod_expected) + + +def test_fuse_simple(): + """Simple testcase.""" + + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor((), "float32")) + + with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, p0) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + fused_add_exp_squeeze = bb.get().get_global_var("fused_add_exp_squeeze") + + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + gv = bb.emit_output( + relax.Call(fused_add_exp_squeeze, [x, relax.const(1, "float32")]) + ) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def test_conv2d_fuse(): + """Test fusion case of conv2d""" + + def before(dtype): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) + w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype)) + w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype)) + w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype)) + with bb.function("main", [x, w1, w2, w3]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype)) + lv1 = bb.emit_te(topi.nn.conv2d, lv0, w1, strides=1, padding=1, dilation=1) + # this is the next dominator. 
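+                # (every path from the conv2d output lv1 reconverges at lv3, its
+                # post-dominator, so FuseOps can grow the conv2d group up to lv3)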
+ lv2 = bb.emit_te(topi.add, relax.const(1, dtype), lv1) + lv3 = bb.emit_te(topi.add, lv1, lv2) + # second path + lv4 = bb.emit_te(topi.nn.conv2d, lv3, w2, strides=1, padding=0, dilation=1) + lv5 = bb.emit_te(topi.nn.conv2d, lv3, w3, strides=1, padding=1, dilation=1) + gv = bb.emit_output(bb.call_te(topi.add, lv4, lv5)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(dtype): + bb = relax.BlockBuilder() + + # Grouped function 1 + x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) + w = relax.Var("w", R.Tensor((16, 16, 3, 3), dtype)) + p0 = relax.Var("p0", R.Tensor((), dtype)) + with bb.function("fused_conv2d_add1_add2", [x, w, p0], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.conv2d, + x, + w, + strides=1, + padding=1, + dilation=1, + primfunc_name_hint="conv2d", + ) + lv1 = bb.emit_te(topi.add, p0, lv0, primfunc_name_hint="add1") + gv = bb.emit_output(bb.call_te(topi.add, lv0, lv1, primfunc_name_hint="add2")) + bb.emit_func_output(gv) + + # Grouped function 2 + x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) + w = relax.Var("w", R.Tensor((16, 16, 1, 1), dtype)) + y = relax.Var("y", R.Tensor((1, 16, 64, 64), dtype)) + with bb.function("fused_conv2d1_add2", [x, w, y], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.conv2d, + x, + w, + strides=1, + padding=0, + dilation=1, + primfunc_name_hint="conv2d1", + ) + gv = bb.emit_output(bb.call_te(topi.add, lv0, y, primfunc_name_hint="add2")) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + mod = bb.get() + fused_conv2d_add1_add2 = mod.get_global_var("fused_conv2d_add1_add2") + fused_conv2d1_add2 = mod.get_global_var("fused_conv2d1_add2") + + # Main function + x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) + w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype)) + w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype)) + w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype)) + with bb.function("main", [x, w1, w2, w3]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype)) + lv1 = bb.emit(relax.Call(fused_conv2d_add1_add2, [lv0, w1, relax.const(1, dtype)])) + lv2 = bb.emit_te( + topi.nn.conv2d, + lv1, + w3, + strides=1, + padding=1, + dilation=1, + ) + gv = bb.emit_output(relax.Call(fused_conv2d1_add2, [lv1, w2, lv2])) + bb.emit_func_output(gv) + + return bb.get() + + _check(before("float32"), expected("float32")) + _check(before("float16"), expected("float16")) + _check(before("int8"), expected("int8")) + + +def test_concatenate(): + """Test fusion case involving concat op and Tuple node""" + + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.pool2d, + x, + kernel=(2, 2), + stride=(2, 2), + dilation=(1, 1), + padding=(0, 0, 0, 0), + pool_type="max", + ) + lv1 = bb.emit_te(topi.nn.upsampling, lv0, scale_h=2.0, scale_w=2.0) + lv2 = bb.emit_te(topi.concatenate, (lv1, x), axis=1) + gv = bb.emit_output(bb.call_te(topi.add, lv2, relax.const(1, "float32"))) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + + # Grouped function + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + w = relax.Var("w", R.Tensor((1, 16, 32, 32), "float32")) + p0 = relax.Var("p0", R.Tensor((), "float32")) + with bb.function("fused_upsampling_concatenate_add", [w, x, p0], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = 
bb.emit_te(topi.nn.upsampling, w, scale_h=2.0, scale_w=2.0) + lv1 = bb.emit_te(topi.concatenate, (lv0, x), axis=1) + gv = bb.emit_output(bb.call_te(topi.add, lv1, p0)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + fused_upsampling_concatenate_add = bb.get().get_global_var( + "fused_upsampling_concatenate_add" + ) + + # Main function + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.pool2d, + x, + kernel=(2, 2), + stride=(2, 2), + dilation=(1, 1), + padding=(0, 0, 0, 0), + pool_type="max", + ) + gv = bb.emit_output( + relax.Call( + fused_upsampling_concatenate_add, (lv0, x, relax.const(1, "float32")) + ) + ) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def test_tuple_root(): + """Test fusion case where Tuple node is the root in its group""" + + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.pool2d, + x, + kernel=(2, 2), + stride=(2, 2), + dilation=(1, 1), + padding=(0, 0, 0, 0), + pool_type="max", + ) + lv1 = bb.emit_te(topi.nn.upsampling, lv0, scale_h=2.0, scale_w=2.0) + gv = bb.emit_output((lv1, x)) + bb.emit_func_output(gv) + + return bb.get() + + # The fusion is supposed to make no change. + _check(before(), before()) + + +def test_fuse_tuple_get_elemwise(): + def before(dim: int): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((1, dim), "float32")) + w = relax.Var("w", R.Tensor((3 * dim, dim), "float32")) + with bb.function("main", [x, w]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.nn.dense, x, w) + lv1 = bb.emit_te(topi.split, lv0, indices_or_sections=3, axis=1) + lv2 = bb.emit(relax.TupleGetItem(lv1, 0)) + lv3 = bb.emit_te(topi.sigmoid, lv2) + lv4 = bb.emit(relax.TupleGetItem(lv1, 1)) + lv5 = bb.emit_te(topi.tanh, lv4) + lv6 = bb.emit(relax.TupleGetItem(lv1, 2)) + lv7 = bb.emit_te(topi.exp, lv6) + lv8 = bb.emit_te(topi.multiply, lv5, lv7) + gv = bb.emit_output(bb.call_te(topi.add, lv3, lv8)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(dim: int): + bb = relax.BlockBuilder() + + # Grouped function + dense = relax.Var("dense", R.Tensor((1, 3 * dim), "float32")) + with bb.function( + "fused_split_sigmoid_tanh_exp_multiply_add", [dense], attrs={"Primitive": 1} + ): + with bb.dataflow(): + lv0 = bb.emit_te(topi.split, dense, indices_or_sections=3, axis=1) + lv1 = bb.emit(relax.TupleGetItem(lv0, 0)) + lv2 = bb.emit_te(topi.sigmoid, lv1) + lv3 = bb.emit(relax.TupleGetItem(lv0, 1)) + lv4 = bb.emit_te(topi.tanh, lv3) + lv5 = bb.emit(relax.TupleGetItem(lv0, 2)) + lv6 = bb.emit_te(topi.exp, lv5) + lv7 = bb.emit_te(topi.multiply, lv4, lv6) + gv = bb.emit_output(bb.call_te(topi.add, lv2, lv7)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + fused_split_sigmoid_tanh_exp_multiply_add = bb.get().get_global_var( + "fused_split_sigmoid_tanh_exp_multiply_add" + ) + + # Main function + x = relax.Var("x", R.Tensor((1, dim), "float32")) + w = relax.Var("w", R.Tensor((3 * dim, dim), "float32")) + with bb.function("main", [x, w]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.nn.dense, x, w) + gv = bb.emit_output(relax.Call(fused_split_sigmoid_tanh_exp_multiply_add, (lv0,))) + bb.emit_func_output(gv) + + return bb.get() + + dim = 10 + _check(before(dim), expected(dim)) + + +def test_tuple_get_root(): + def before(dim: int): + bb = 
relax.BlockBuilder() + x = relax.Var("x", R.Tensor((1, 3 * dim), "float32")) + w = relax.Var("w", R.Tensor((dim, dim), "float32")) + with bb.function("main", [x, w]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.split, x, indices_or_sections=3, axis=1) + lv1 = bb.emit(relax.TupleGetItem(lv0, 0)) + gv = bb.emit_output(bb.call_te(topi.nn.dense, lv1, w)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(dim: int): + bb = relax.BlockBuilder() + + # Grouped function + x = relax.Var("x", R.Tensor((1, 3 * dim), "float32")) + with bb.function("fused_split", [x], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.split, x, indices_or_sections=3, axis=1) + gv = bb.emit_output(relax.TupleGetItem(lv0, 0)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + fused_split = bb.get().get_global_var("fused_split") + + # Main function + x = relax.Var("x", R.Tensor((1, 3 * dim), "float32")) + w = relax.Var("w", R.Tensor((dim, dim), "float32")) + with bb.function("main", [x, w]): + with bb.dataflow(): + lv0 = bb.emit(relax.Call(fused_split, (x,))) + gv = bb.emit_output(bb.call_te(topi.nn.dense, lv0, w)) + bb.emit_func_output(gv) + + return bb.get() + + dim = 10 + _check(before(dim), expected(dim)) + + +def test_tuple_intermediate(): + def before(): + bb = relax.BlockBuilder() + + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.squeeze, x) + lv1 = bb.emit_te(topi.add, lv0, relax.const(1, "float32")) + lv2 = bb.emit_te(topi.squeeze, lv0) + lv3 = bb.emit_te(topi.add, lv2, relax.const(1, "float32")) + lv4 = bb.emit_te(topi.add, lv3, relax.const(1, "float32")) + lv5 = bb.emit_te(topi.add, lv0, relax.const(1, "float32")) + lv6 = bb.emit_te(topi.concatenate, (lv1, lv4, lv5), axis=1) + lv7 = bb.emit_te(topi.squeeze, lv6) + gv = bb.emit_output(bb.call_te(topi.add, lv7, relax.const(1, "float32"))) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + + # Grouped function + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + p0 = relax.Var("p0", R.Tensor((), "float32")) + p1 = relax.Var("p1", R.Tensor((), "float32")) + p2 = relax.Var("p2", R.Tensor((), "float32")) + p3 = relax.Var("p3", R.Tensor((), "float32")) + p4 = relax.Var("p4", R.Tensor((), "float32")) + with bb.function( + "fused_squeeze_add_squeeze1_add_add_add_concatenate_squeeze2_add1", + [x, p0, p1, p2, p3, p4], + attrs={"Primitive": 1}, + ): + with bb.dataflow(): + lv0 = bb.emit_te(topi.squeeze, x) + lv1 = bb.emit_te(topi.add, lv0, p0) + lv2 = bb.emit_te(topi.squeeze, lv0) + lv3 = bb.emit_te(topi.add, lv2, p1) + lv4 = bb.emit_te(topi.add, lv3, p2) + lv5 = bb.emit_te(topi.add, lv0, p3) + lv6 = bb.emit_te(topi.concatenate, (lv1, lv4, lv5), axis=1) + lv7 = bb.emit_te(topi.squeeze, lv6) + gv = bb.emit_output(bb.call_te(topi.add, lv7, p4)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + fused_func = bb.get().get_global_var( + "fused_squeeze_add_squeeze1_add_add_add_concatenate_squeeze2_add1" + ) + + # Main func + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + gv = bb.emit_output( + relax.Call( + fused_func, + ( + x, + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + ), + ) + ) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def 
test_tuple_consecutive(): + def before(): + bb = relax.BlockBuilder() + + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv1 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv2 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv3 = bb.emit_te(topi.concatenate, (lv0, lv1, lv2), axis=1) + lv4 = bb.emit_te(topi.add, lv3, relax.const(1, "float32")) + lv5 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv6 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv7 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv8 = bb.emit_te(topi.concatenate, (lv5, lv6, lv7), axis=1) + lv9 = bb.emit_te(topi.add, lv8, relax.const(1, "float32")) + lv10 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv11 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv12 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv13 = bb.emit_te(topi.concatenate, (lv10, lv11, lv12), axis=1) + lv14 = bb.emit_te(topi.add, lv13, relax.const(1, "float32")) + lv15 = bb.emit_te(topi.concatenate, (lv4, lv9, lv14), axis=1) + lv16 = bb.emit_te( + topi.nn.pool2d, + lv15, + kernel=(2, 2), + stride=(2, 2), + dilation=(1, 1), + padding=(0, 0, 0, 0), + pool_type="max", + ) + lv17 = bb.emit_te(topi.add, lv16, relax.const(1, "float32")) + lv18 = bb.emit_te(topi.add, lv17, relax.const(1, "float32")) + gv = bb.emit_output((lv17, lv18)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + + # Grouped function 1 + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + p0 = relax.Var("p0", R.Tensor((), "float32")) + p1 = relax.Var("p1", R.Tensor((), "float32")) + p2 = relax.Var("p2", R.Tensor((), "float32")) + p3 = relax.Var("p3", R.Tensor((), "float32")) + p4 = relax.Var("p4", R.Tensor((), "float32")) + p5 = relax.Var("p5", R.Tensor((), "float32")) + p6 = relax.Var("p6", R.Tensor((), "float32")) + p7 = relax.Var("p7", R.Tensor((), "float32")) + p8 = relax.Var("p8", R.Tensor((), "float32")) + p9 = relax.Var("p9", R.Tensor((), "float32")) + p10 = relax.Var("p10", R.Tensor((), "float32")) + p11 = relax.Var("p11", R.Tensor((), "float32")) + with bb.function( + "fused_add_add_add_concatenate_add1_add_add_add_concatenate_add1_add_add_add_concatenate_add1_concatenate1", + [x, p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11], + attrs={"Primitive": 1}, + ): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, p0) + lv1 = bb.emit_te(topi.add, x, p1) + lv2 = bb.emit_te(topi.add, x, p2) + lv3 = bb.emit_te(topi.concatenate, (lv0, lv1, lv2), axis=1) + lv4 = bb.emit_te(topi.add, lv3, p3) + lv5 = bb.emit_te(topi.add, x, p4) + lv6 = bb.emit_te(topi.add, x, p5) + lv7 = bb.emit_te(topi.add, x, p6) + lv8 = bb.emit_te(topi.concatenate, (lv5, lv6, lv7), axis=1) + lv9 = bb.emit_te(topi.add, lv8, p7) + lv10 = bb.emit_te(topi.add, x, p8) + lv11 = bb.emit_te(topi.add, x, p9) + lv12 = bb.emit_te(topi.add, x, p10) + lv13 = bb.emit_te(topi.concatenate, (lv10, lv11, lv12), axis=1) + lv14 = bb.emit_te(topi.add, lv13, p11) + gv = bb.emit_output(bb.call_te(topi.concatenate, (lv4, lv9, lv14), axis=1)) + bb.emit_func_output(gv) + + # Grouped function 2 + concat = relax.Var("concat", R.Tensor((1, 144, 64, 64), "float32")) + p0 = relax.Var("p0", R.Tensor((), "float32")) + with bb.function("fused_pool2d_add2", [concat, p0], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.pool2d, + concat, + kernel=(2, 2), + stride=(2, 2), + dilation=(1, 1), + 
padding=(0, 0, 0, 0), + pool_type="max", + ) + gv = bb.emit_output(bb.call_te(topi.add, lv0, p0)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + mod = bb.get() + fused_func1 = mod.get_global_var( + "fused_add_add_add_concatenate_add1_add_add_add_concatenate_add1_add_add_add_concatenate_add1_concatenate1" + ) + fused_func2 = mod.get_global_var("fused_pool2d_add2") + + # Main function + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit( + relax.Call( + fused_func1, + ( + x, + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + ), + ) + ) + lv1 = bb.emit(relax.Call(fused_func2, (lv0, relax.const(1, "float32")))) + lv2 = bb.emit_te(topi.add, lv1, relax.const(1, "float32")) + gv = bb.emit_output((lv1, lv2)) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def test_inception_like(): + def before(): + bb = relax.BlockBuilder() + + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + w0 = relax.Var("w0", R.Tensor((16, 16, 3, 3), "float32")) + w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), "float32")) + w2 = relax.Var("w2", R.Tensor((16, 32, 3, 3), "float32")) + w3 = relax.Var("w3", R.Tensor((16, 32, 3, 3), "float32")) + with bb.function("main", [x, w0, w1, w2, w3]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.nn.conv2d, x, w0, strides=1, padding=1, dilation=1) + lv1 = bb.emit_te(topi.nn.relu, lv0) + lv2 = bb.emit_te(topi.nn.conv2d, x, w1, strides=1, padding=1, dilation=1) + lv3 = bb.emit_te(topi.nn.relu, lv2) + lv4 = bb.emit_te(topi.concatenate, (lv1, lv3), axis=1) + lv5 = bb.emit_te(topi.nn.conv2d, lv4, w2, strides=1, padding=1, dilation=1) + lv6 = bb.emit_te(topi.nn.relu, lv5) + lv7 = bb.emit_te(topi.nn.conv2d, lv4, w3, strides=1, padding=1, dilation=1) + lv8 = bb.emit_te(topi.nn.relu, lv7) + gv = bb.emit_output(bb.call_te(topi.concatenate, (lv6, lv8), axis=1)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + + # Grouped function 1 + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + w = relax.Var("w", R.Tensor((16, 16, 3, 3), "float32")) + with bb.function("fused_conv2d_relu", [x, w], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.conv2d, + x, + w, + strides=1, + padding=1, + dilation=1, + primfunc_name_hint="conv2d", + ) + gv = bb.emit_output(bb.call_te(topi.nn.relu, lv0)) + bb.emit_func_output(gv) + + # Grouped function 2 + x = relax.Var("x", R.Tensor((1, 32, 64, 64), "float32")) + w = relax.Var("w", R.Tensor((16, 32, 3, 3), "float32")) + with bb.function("fused_conv2d1_relu", [x, w], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.conv2d, + x, + w, + strides=1, + padding=1, + dilation=1, + primfunc_name_hint="conv2d1", + ) + gv = bb.emit_output(bb.call_te(topi.nn.relu, lv0)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + mod = bb.get() + fused_conv2d_relu1 = mod.get_global_var("fused_conv2d_relu") + fused_conv2d_relu2 = mod.get_global_var("fused_conv2d1_relu") + + # Main function + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + w0 = relax.Var("w0", R.Tensor((16, 16, 3, 3), "float32")) + 
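+        # w1 matches w0's shape, so both branches reuse fused_conv2d_relu; likewise
+        # w3 matches w2, so both later branches reuse fused_conv2d1_relu.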
w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), "float32")) + w2 = relax.Var("w2", R.Tensor((16, 32, 3, 3), "float32")) + w3 = relax.Var("w3", R.Tensor((16, 32, 3, 3), "float32")) + with bb.function("main", [x, w0, w1, w2, w3]): + with bb.dataflow(): + lv0 = bb.emit(relax.Call(fused_conv2d_relu1, (x, w0))) + lv1 = bb.emit(relax.Call(fused_conv2d_relu1, (x, w1))) + lv2 = bb.emit_te(topi.concatenate, (lv0, lv1), axis=1) + lv3 = bb.emit(relax.Call(fused_conv2d_relu2, (lv2, w2))) + lv4 = bb.emit(relax.Call(fused_conv2d_relu2, (lv2, w3))) + gv = bb.emit_output(bb.call_te(topi.concatenate, (lv3, lv4), axis=1)) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def test_fuse_parallel_injective(): + def before(): + bb = relax.BlockBuilder() + + x = relax.Var("x", R.Tensor((10, 20), "int32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, "int32")) + lv1 = bb.emit_te(topi.squeeze, lv0) + lv2 = bb.emit_te(topi.transpose, lv0, axes=[1, 0]) + lv3 = bb.emit_te(topi.transpose, lv2, axes=[1, 0]) + gv = bb.emit_output(bb.call_te(topi.left_shift, lv1, lv3)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + + # Grouped function + x = relax.Var("x", R.Tensor((10, 20), "int32")) + p0 = relax.Var("p0", R.Tensor((), "int32")) + with bb.function( + "fused_add_squeeze_transpose_transpose1_left_shift", [x, p0], attrs={"Primitive": 1} + ): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, p0) + lv1 = bb.emit_te(topi.squeeze, lv0) + lv2 = bb.emit_te(topi.transpose, lv0, axes=[1, 0]) + lv3 = bb.emit_te(topi.transpose, lv2, axes=[1, 0], primfunc_name_hint="transpose1") + gv = bb.emit_output(bb.call_te(topi.left_shift, lv1, lv3)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + fused_func = bb.get().get_global_var("fused_add_squeeze_transpose_transpose1_left_shift") + + # Main function + x = relax.Var("x", R.Tensor((10, 20), "int32")) + with bb.function("main", [x]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(fused_func, (x, relax.const(1, "int32")))) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def test_softmax(): + """Test if softmax can be fused with following ops.""" + + def before(): + bb = relax.BlockBuilder() + + x = relax.Var("x", R.Tensor((16, 16), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.nn.softmax, x) + gv = bb.emit_output(bb.call_te(topi.cast, lv0, dtype="float16")) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + + # Grouped function + x = relax.Var("x", R.Tensor((16, 16), "float32")) + with bb.function("fused_softmax_cast", [x], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.nn.softmax, x) + gv = bb.emit_output(bb.call_te(topi.cast, lv0, dtype="float16")) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + fused_func = bb.get().get_global_var("fused_softmax_cast") + + # Main function + x = relax.Var("x", R.Tensor((16, 16), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(fused_func, (x,))) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py new file mode 100644 index 0000000000000..91edab2bbb984 --- 
/dev/null +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -0,0 +1,563 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm import relax, topi +from tvm.script import relax as R + + +def _check(mod_before, mod_expected): + mod = relax.transform.FuseTIR()(mod_before) + tvm.ir.assert_structural_equal(mod, mod_expected) + + +def test_simple(): + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor([], "float32")) + + with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": True}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, p0) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + fused_add_exp_squeeze = bb.get().get_global_var("fused_add_exp_squeeze") + + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor([], "float32")) + with bb.function("main", [x, p0]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(fused_add_exp_squeeze, [x, p0])) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + def fused_add_exp_squeeze(x, p0): + add = topi.add(x, p0) + exp = topi.exp(add) + squeeze = topi.squeeze(exp) + return squeeze + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor([], "float32")) + with bb.function("main", [x, p0]): + with bb.dataflow(): + gv = bb.emit_output(bb.call_te(fused_add_exp_squeeze, x, p0)) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_conv2d_fuse(): + def before(dtype): + bb = relax.BlockBuilder() + + # Grouped function 1 + x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) + w = relax.Var("w", R.Tensor((16, 16, 3, 3), dtype)) + p0 = relax.Var("p0", R.Tensor((), dtype)) + with bb.function("fused_conv2d_add1_add2", [x, w, p0], attrs={"Primitive": True}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.conv2d, + x, + w, + strides=1, + padding=1, + dilation=1, + primfunc_name_hint="conv2d", + ) + lv1 = bb.emit_te(topi.add, p0, lv0, primfunc_name_hint="add1") + gv = bb.emit_output(bb.call_te(topi.add, lv0, lv1, primfunc_name_hint="add2")) + bb.emit_func_output(gv) + + # Grouped function 2 + x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) + w = relax.Var("w", R.Tensor((16, 16, 1, 1), dtype)) + y = relax.Var("y", R.Tensor((1, 16, 64, 64), dtype)) + with bb.function("fused_conv2d1_add2", [x, w, y], attrs={"Primitive": True}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.conv2d, + x, + w, + strides=1, + padding=0, + dilation=1, + primfunc_name_hint="conv2d1", + ) + gv = bb.emit_output(bb.call_te(topi.add, lv0, y, primfunc_name_hint="add2")) + bb.emit_func_output(gv) + + # Get the 
global variables of the grouped functions + mod = bb.get() + fused_conv2d_add1_add2 = mod.get_global_var("fused_conv2d_add1_add2") + fused_conv2d1_add2 = mod.get_global_var("fused_conv2d1_add2") + + # Main function + x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) + w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype)) + w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype)) + w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype)) + with bb.function("main", [x, w1, w2, w3]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype)) + lv1 = bb.emit(relax.Call(fused_conv2d_add1_add2, [lv0, w1, relax.const(1, dtype)])) + lv2 = bb.emit_te( + topi.nn.conv2d, + lv1, + w3, + strides=1, + padding=1, + dilation=1, + ) + gv = bb.emit_output(relax.Call(fused_conv2d1_add2, [lv1, w2, lv2])) + bb.emit_func_output(gv) + + return bb.get() + + def expected(dtype): + def fused_conv2d_add1_add2(x, w, p): + conv = topi.nn.conv2d(x, w, strides=1, padding=1, dilation=1) + add = topi.add(p, conv) + return topi.add(conv, add) + + def fused_conv2d1_add2(x, w, p): + conv = topi.nn.conv2d(x, w, strides=1, padding=0, dilation=1) + return topi.add(conv, p) + + bb = relax.BlockBuilder() + + # Main function + x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) + w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype)) + w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype)) + w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype)) + with bb.function("main", [x, w1, w2, w3]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype)) + lv1 = bb.emit_te(fused_conv2d_add1_add2, lv0, w1, relax.const(1, dtype)) + lv2 = bb.emit_te( + topi.nn.conv2d, + lv1, + w3, + strides=1, + padding=1, + dilation=1, + ) + gv = bb.emit_output(bb.call_te(fused_conv2d1_add2, lv1, w2, lv2)) + bb.emit_func_output(gv) + + return bb.get() + + _check(before("float32"), expected("float32")) + + +def test_two_subfunction(): + def before(): + bb = relax.BlockBuilder() + x1 = relax.Var("x1", R.Tensor([10, 20], "float32")) + with bb.function("fused_exp_squeeze", [x1], attrs={"Primitive": True}): + with bb.dataflow(): + lv1 = bb.emit_te(topi.exp, x1) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_exp_squeeze") + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit(relax.Call(func_gv, [x])) + lv2 = bb.emit(relax.Call(func_gv, [lv])) + gv = bb.emit_output(lv2) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_exp_squeeze(x): + exp = topi.exp(x) + squeeze = topi.squeeze(exp) + return squeeze + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit_te(fused_exp_squeeze, x) + lv2 = bb.emit_te(fused_exp_squeeze, lv) + gv = bb.emit_output(lv2) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_same_primfunc(): + def before(): + bb = relax.BlockBuilder() + x1 = relax.Var("x1", R.Tensor([10, 20], "float32")) + with bb.function("fused_exp_exp_squeeze", [x1], attrs={"Primitive": True}): + with bb.dataflow(): + lv1 = bb.emit_te(topi.exp, x1) + lv2 = bb.emit_te(topi.exp, lv1) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv2)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_exp_exp_squeeze") + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with 
bb.dataflow(): + lv = bb.emit(relax.Call(func_gv, [x])) + gv = bb.emit_output(lv) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_exp_exp_squeeze(x): + exp = topi.exp(x) + exp = topi.exp(exp) + squeeze = topi.squeeze(exp) + return squeeze + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit_te(fused_exp_exp_squeeze, x) + gv = bb.emit_output(lv) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_with_tuple_as_param(): + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")])) + with bb.function("fused_exp_add", [x], attrs={"Primitive": True}): + with bb.dataflow(): + lv0 = bb.emit(relax.TupleGetItem(x, 0)) + lv1 = bb.emit(relax.TupleGetItem(x, 1)) + lv2 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.add, lv2, lv1)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_exp_add") + x = relax.Var("x", R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")])) + with bb.function("main", [x]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(func_gv, [x])) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_exp_add(x1, x2): + exp = topi.exp(x1) + return topi.add(exp, x2) + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")])) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit(relax.TupleGetItem(x, 0)) + lv1 = bb.emit(relax.TupleGetItem(x, 1)) + gv = bb.emit_output(bb.call_te(fused_exp_add, lv0, lv1)) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_with_nested_tuple_as_param(): + tuple_struct_info = R.Tuple( + [R.Tensor([10], "float32"), R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")])] + ) + + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", tuple_struct_info) + with bb.function("fused_exp_add_add", [x], attrs={"Primitive": True}): + with bb.dataflow(): + lv0 = bb.emit(relax.TupleGetItem(x, 0)) + lv0_exp = bb.emit_te(topi.exp, lv0) + lv1 = bb.emit(relax.TupleGetItem(x, 1)) + lv1_0 = bb.emit(relax.TupleGetItem(lv1, 0)) + lv1_1 = bb.emit(relax.TupleGetItem(lv1, 1)) + lv2 = bb.emit_te(topi.add, lv1_0, lv1_1) + gv = bb.emit_output(bb.call_te(topi.add, lv0_exp, lv2)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_exp_add_add") + x = relax.Var("x", tuple_struct_info) + with bb.function("main", [x]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(func_gv, [x])) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_exp_add_add(x1, x2, x3): + exp = topi.exp(x1) + add = topi.add(x2, x3) + return topi.add(exp, add) + + bb = relax.BlockBuilder() + x = relax.Var("x", tuple_struct_info) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit(relax.TupleGetItem(x, 0)) + lv1 = bb.emit(relax.TupleGetItem(x, 1)) + lv2 = bb.emit(relax.TupleGetItem(lv1, 0)) + lv3 = bb.emit(relax.TupleGetItem(lv1, 1)) + gv = bb.emit_output(bb.call_te(fused_exp_add_add, lv0, lv2, lv3)) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_with_call_tir_in_main(): + def before(): + bb = relax.BlockBuilder() + x1 = relax.Var("x1", R.Tensor([10, 20], "float32")) + with bb.function("fused_exp_squeeze", [x1], attrs={"Primitive": True}): + with 
bb.dataflow(): + lv = bb.emit_te(topi.exp, x1) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_exp_squeeze") + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit(relax.Call(func_gv, [x])) + lv1 = bb.emit_te(topi.add, lv0, relax.const(1, "float32")) + gv = bb.emit_output(lv1) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_exp_squeeze(x): + exp = topi.exp(x) + squeeze = topi.squeeze(exp) + return squeeze + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit_te(fused_exp_squeeze, x) + lv2 = bb.emit_te(topi.add, lv, relax.const(1, "float32")) + gv = bb.emit_output(lv2) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_with_const_in_argument(): + def before(): + bb = relax.BlockBuilder() + x1 = relax.Var("x1", R.Tensor([10, 20], "float32")) + x2 = relax.Var("x2", R.Tensor([], "float32")) + with bb.function("fused_add_exp_squeeze", [x1, x2], attrs={"Primitive": True}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x1, x2) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_add_exp_squeeze") + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit(relax.Call(func_gv, [x, relax.const(1, "float32")])) + gv = bb.emit_output(lv) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_add_exp_squeeze(x, y): + add = topi.add(x, y) + exp = topi.exp(add) + squeeze = topi.squeeze(exp) + return squeeze + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit_te(fused_add_exp_squeeze, x, relax.const(1, "float32")) + gv = bb.emit_output(lv) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_tuple_output(): + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor([], "float32")) + + with bb.function("fused_add_exp", [x, p0], attrs={"Primitive": True}): + with bb.dataflow(): + gv0 = bb.emit_output(bb.call_te(topi.add, x, p0)) + gv1 = bb.emit_output(bb.call_te(topi.exp, gv0)) + bb.emit_func_output(relax.Tuple([gv0, gv1])) + fused_add_exp = bb.get().get_global_var("fused_add_exp") + + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor([], "float32")) + with bb.function("main", [x, p0]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(fused_add_exp, [x, p0])) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + def fused_add_exp(x, p0): + add = topi.add(x, p0) + exp = topi.exp(add) + return add, exp + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor([], "float32")) + with bb.function("main", [x, p0]): + with bb.dataflow(): + gv = bb.emit_output(bb.call_te(fused_add_exp, x, p0)) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_with_immediate_tuple(): + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + y = relax.Var("y", R.Tensor([10, 20], "float32")) + + with 
bb.function("fused_add", [x, y], attrs={"Primitive": True}): + with bb.dataflow(): + lv_tuple = bb.emit(relax.Tuple([x, relax.Tuple([x, y])])) + lv_x = bb.emit(relax.TupleGetItem(lv_tuple, 0)) + lv0 = bb.emit(relax.TupleGetItem(lv_tuple, 1)) + lv_y = bb.emit(relax.TupleGetItem(lv0, 1)) + gv = bb.emit_output(bb.call_te(topi.add, lv_x, lv_y)) + bb.emit_func_output(gv) + fused_add = bb.get().get_global_var("fused_add") + + x = relax.Var("x", R.Tensor([10, 20], "float32")) + y = relax.Var("y", R.Tensor([10, 20], "float32")) + with bb.function("main", [x, y]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(fused_add, [x, y])) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + y = relax.Var("y", R.Tensor([10, 20], "float32")) + with bb.function("main", [x, y]): + with bb.dataflow(): + gv = bb.emit_output(bb.call_te(topi.add, x, y, primfunc_name_hint="fused_add")) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_return_partial_result(): + def te_argmax_idx_val(val): + from tvm import te + + def f_combine(x, y): + lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0]) + rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1]) + return lhs, rhs + + def f_identity(dtype0: tvm.DataType, dtype1: tvm.DataType): + return tvm.tir.const(-1, dtype0), tvm.te.min_value(dtype1) + + argmax = te.comm_reducer(f_combine, f_identity, name="argmax") + m, n = val.shape + k = te.reduce_axis((0, n), "k") + max_idx, max_val = te.compute( + (m,), lambda i: argmax((k.var, val[i, k]), axis=k), name="argmax" + ) + return max_idx, max_val + + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + offset = relax.Var("offset", R.Tensor([10], "int32")) + with bb.function("fused_argmax_add", [x, offset], attrs={"Primitive": True}): + with bb.dataflow(): + lv = bb.emit_te(te_argmax_idx_val, x) + idx = bb.emit(relax.TupleGetItem(lv, 0)) + gv = bb.emit_output(bb.call_te(topi.add, idx, offset)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_argmax_add") + x = relax.Var("x", R.Tensor([10, 20], "float32")) + offset = relax.Var("x", R.Tensor([10], "int32")) + with bb.function("main", [x, offset]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(func_gv, [x, offset])) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_argmax_add(x, offset): + idx, value = te_argmax_idx_val(x) + idx = topi.add(idx, offset) + return idx + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + offset = relax.Var("offset", R.Tensor([10], "int32")) + with bb.function("main", [x, offset]): + with bb.dataflow(): + gv = bb.emit_output(bb.call_te(fused_argmax_add, x, offset)) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index f6d2e4c20e48e..6e9e14d3dc470 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -1073,5 +1073,4 @@ def mul_add(x: R.Tensor) -> R.Tensor: if __name__ == "__main__": - test_cross_function_call() tvm.testing.main()
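
Taken together, the three passes form a pipeline: AnnotateTIROpPattern supplies the
per-PrimFunc pattern kinds that FuseOps' grouping algorithm consumes, and FuseTIR lowers
each resulting primitive relax function into a single PrimFunc. A minimal usage sketch
(the toy module is an assumption for illustration, built the same way as in the tests
above; the pass names are those registered in this patch):

    import tvm
    from tvm import relax, topi
    from tvm.script import relax as R

    # Build a toy module: add followed by exp, each emitted as a call_tir via emit_te.
    bb = relax.BlockBuilder()
    x = relax.Var("x", R.Tensor((10, 20), "float32"))
    with bb.function("main", [x]):
        with bb.dataflow():
            lv0 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
            gv = bb.emit_output(bb.call_te(topi.exp, lv0))
        bb.emit_func_output(gv)
    mod = bb.get()

    # Annotate patterns, group the two element-wise ops, then fuse their TIR bodies
    # into one PrimFunc.
    seq = tvm.transform.Sequential(
        [
            relax.transform.AnnotateTIROpPattern(),
            relax.transform.FuseOps(),
            relax.transform.FuseTIR(),
        ]
    )
    fused_mod = seq(mod)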