diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 5a7b85ac1376..eaad44a93ace 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -253,6 +253,27 @@ TVM_DLL Pass LegalizeOps(Optional<Map<String, PackedFunc>> cmap, bool enable_war
  */
 TVM_DLL Pass RealizeVDevice();
 
+/*!
+ * \brief Attach layout-free buffers to the tir::PrimFunc.
+ *
+ * This pass attaches layout-free buffers to the tir::PrimFunc according to
+ * how the function is used in the relax function. Currently, the layout-free
+ * buffers are the model weights and relax constants.
+ *
+ * \note We recommend applying CanonicalizeBindings before this pass.
+ * \return The Pass.
+ */
+TVM_DLL Pass AttachAttrLayoutFreeBuffers();
+
+/*!
+ * \brief Split the layout rewrite preproc block into a separate tir::PrimFunc.
+ *
+ * This pass is used to prepack weights after meta_schedule tuning.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass SplitLayoutRewritePreproc();
+
 /*!
  * \brief Lift transformation of the parameters of a function.
  *
diff --git a/python/tvm/relax/frontend/nn/__init__.py b/python/tvm/relax/frontend/nn/__init__.py
index a8200d8dd627..f490af7062b0 100644
--- a/python/tvm/relax/frontend/nn/__init__.py
+++ b/python/tvm/relax/frontend/nn/__init__.py
@@ -23,6 +23,8 @@ from .modules import (
     GELU,
     Conv1D,
+    Conv2D,
+    Conv3D,
     ConvTranspose1D,
     Embedding,
     GroupNorm,
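The `nn/__init__.py` hunk above only re-exports `Conv2D` and `Conv3D` from `nn.modules`. A minimal usage sketch (the constructor arguments are assumed to mirror the existing `Conv1D` module; check `modules.py` for the exact signature):

.. code-block:: python

    from tvm.relax.frontend import nn

    class TinyNet(nn.Module):
        def __init__(self):
            # Assumed Conv1D-like signature: (in_channels, out_channels, kernel_size, ...)
            self.conv = nn.Conv2D(3, 8, 3)

        def forward(self, x: nn.Tensor):
            return self.conv(x)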
diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py
index 582f5111aaf5..fe3dbc99fc15 100644
--- a/python/tvm/relax/pipeline.py
+++ b/python/tvm/relax/pipeline.py
@@ -109,6 +109,7 @@ def static_shape_tuning_pipeline(
     total_trials: int,
     target: Union[str, tvm.target.Target],
     work_dir: str = "tuning_logs",
+    cpu_weight_prepack: bool = False,
 ):
     """Tune the static shape model and store the log to database.
 
@@ -122,18 +123,65 @@ def static_shape_tuning_pipeline(
 
     work_dir : str
         The directory to store the tuning logs.
+
+    cpu_weight_prepack : bool
+        Whether to enable the CPU weight prepacking feature.
+
+    Note
+    ----
+    `cpu_weight_prepack` is expected to be `True` when running on CPU for
+    better performance. However, it requires an explicit layout transformation
+    step, performed by calling the corresponding vm function, which changes the
+    deployment interface; we therefore disable it by default. Here is an
+    example of enabling it:
+
+    .. code-block:: python
+
+        mod = relax.pipeline.static_shape_tuning_pipeline(
+            total_trials=1000,
+            target="llvm -num-cores 16",
+            work_dir="tuning_logs",
+            cpu_weight_prepack=True,
+        )(mod)
+
+        ex = relax.build(mod, target=target)
+        vm = relax.VirtualMachine(ex, device=tvm.cpu())
+
+        # Transform the params using the vm function
+        # the name should be f"{func_name}_transform_params"
+        params = vm["main_transform_params"](params["main"])
+
+        input_data = tvm.nd.array(np.random.randn(1, 3, 224, 224).astype("float32"))
+        out = vm["main"](input_data, *params).numpy()
     """
 
     @tvm.transform.module_pass(opt_level=0)
     def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
+        if cpu_weight_prepack:
+            pre_tuning_layout_rewrite = [transform.AttachAttrLayoutFreeBuffers()]
+            post_tuning_layout_rewrite = [
+                transform.SplitLayoutRewritePreproc(),
+                transform.LiftTransformParams(),
+                transform.FoldConstant(),
+            ]
+        else:
+            pre_tuning_layout_rewrite = []
+            post_tuning_layout_rewrite = []
+
         with tvm.target.Target(target):
             mod = tvm.transform.Sequential(
                 [
                     transform.DecomposeOpsForInference(),
                     transform.CanonicalizeBindings(),
                     zero_pipeline(),
-                    transform.MetaScheduleTuneIRMod({}, work_dir, total_trials),
+                    *pre_tuning_layout_rewrite,
+                    # Skip tuning if total_trials is 0
+                    (
+                        transform.MetaScheduleTuneIRMod({}, work_dir, total_trials)
+                        if total_trials > 0
+                        else tvm.transform.Sequential([])
+                    ),
                     transform.MetaScheduleApplyDatabase(work_dir),
+                    *post_tuning_layout_rewrite,
                 ]
             )(mod)
diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
index 1ce864651cd9..16e4800ca33d 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -21,6 +21,7 @@
     AllocateWorkspace,
     AlterOpImpl,
     AnnotateTIROpPattern,
+    AttachAttrLayoutFreeBuffers,
     AttachGlobalSymbol,
     BindParams,
     BindSymbolicVars,
@@ -73,6 +74,7 @@
     RewriteDataflowReshape,
     RunCodegen,
     SplitCallTIRByPattern,
+    SplitLayoutRewritePreproc,
     StaticPlanBlockMemory,
     ToMixedPrecision,
     ToNonDataflow,
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 3330d4098734..603211b59ebc 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -970,6 +970,35 @@ def MergeCompositeFunctions() -> tvm.ir.transform.Pass:
     return _ffi_api.MergeCompositeFunctions()  # type: ignore
 
 
+def AttachAttrLayoutFreeBuffers() -> tvm.ir.transform.Pass:
+    """Attach layout-free buffers to the tir::PrimFunc.
+
+    This pass attaches layout-free buffers to the tir::PrimFunc according to
+    how the function is used in the relax function. Currently, the layout-free
+    buffers are the model weights and relax constants.
+
+    Note that we recommend applying CanonicalizeBindings before this pass.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered pass for attaching layout-free buffers.
+    """
+    return _ffi_api.AttachAttrLayoutFreeBuffers()  # type: ignore
+
+
+def SplitLayoutRewritePreproc() -> tvm.ir.transform.Pass:
+    """Split the TIR layout rewrite into multiple TIR functions.
+    This pass is used to prepack weights after meta_schedule tuning.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered pass for splitting the TIR layout rewrite.
+    """
+    return _ffi_api.SplitLayoutRewritePreproc()  # type: ignore
+
+
 def LiftTransformParams(shared_transform: Union[bool, List[str]] = False) -> tvm.ir.transform.Pass:
     """Lift transformation of the parameters of a function.
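Taken together, the two new passes bracket the MetaSchedule tuning step. A minimal sketch of composing them by hand, outside of `static_shape_tuning_pipeline` (the ordering mirrors the `_pipeline` body above; `mod` is assumed to be a statically-shaped Relax module):

.. code-block:: python

    import tvm
    from tvm.relax import transform

    seq = tvm.transform.Sequential(
        [
            transform.CanonicalizeBindings(),         # recommended before attaching
            transform.AttachAttrLayoutFreeBuffers(),  # mark weights/constants layout-free
            # ... MetaSchedule tuning and database application go here; the
            # RewriteLayout postproc inserts the preproc blocks ...
            transform.SplitLayoutRewritePreproc(),    # split preproc into *_weight_prepack
            transform.LiftTransformParams(),          # lift prepack into *_transform_params
            transform.FoldConstant(),
        ]
    )
    mod = seq(mod)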
diff --git a/src/meta_schedule/postproc/rewrite_layout.cc b/src/meta_schedule/postproc/rewrite_layout.cc
index 71ae43387112..87fa96f67ceb 100644
--- a/src/meta_schedule/postproc/rewrite_layout.cc
+++ b/src/meta_schedule/postproc/rewrite_layout.cc
@@ -249,7 +249,13 @@ class RewriteLayoutNode : public PostprocNode {
   void InitializeWithTuneContext(const TuneContext& context) final {}
 
   // Inherited from PostprocNode
-  bool Apply(const tir::Schedule& sch) final { return tir::RewriteLayout(sch); }
+  bool Apply(const tir::Schedule& sch) final {
+    try {
+      return tir::RewriteLayout(sch);
+    } catch (const std::runtime_error& e) {
+      return false;
+    }
+  }
 
   Postproc Clone() const {
     ObjectPtr<RewriteLayoutNode> n = make_object<RewriteLayoutNode>(*this);
diff --git a/src/relax/transform/attach_attr_layout_free_buffers.cc b/src/relax/transform/attach_attr_layout_free_buffers.cc
new file mode 100644
index 000000000000..64062e224372
--- /dev/null
+++ b/src/relax/transform/attach_attr_layout_free_buffers.cc
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/attach_attr_layout_free_buffers.cc
+ * \brief Attach the layout_free_buffers attribute for layout-free buffers.
+ */
+
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace relax {
+
+class AttrAttacher : public ExprMutator {
+ public:
+  static IRModule Transform(const IRModule& mod) {
+    AttrAttacher mutator(mod);
+    for (auto [gvar, func] : mod->functions) {
+      if (func->IsInstance<FunctionNode>()) {
+        // clear the layout_free_exprs_ for each function
+        mutator.layout_free_exprs_.clear();
+        mutator.builder_->UpdateFunction(gvar, Downcast<BaseFunc>(mutator.VisitExpr(func)));
+      }
+    }
+    return mutator.builder_->GetContextIRModule();
+  }
+
+ private:
+  explicit AttrAttacher(IRModule mod) : ExprMutator(mod), mod_(mod) {}
+
+  using ExprMutator::VisitExpr_;
+  Expr VisitExpr_(const FunctionNode* op) final {
+    if (auto opt_num_input = op->attrs.GetAttr<Integer>(attr::kNumInput)) {
+      ICHECK(layout_free_exprs_.empty()) << "meet a non-global function with num_input attr";
+      size_t num_input = opt_num_input.value()->value;
+      for (size_t i = num_input; i < op->params.size(); i++) {
+        layout_free_exprs_.insert(op->params[i].get());
+      }
+    }
+    return ExprMutator::VisitExpr_(op);
+  }
+
+  Expr VisitExpr_(const ConstantNode* op) final {
+    layout_free_exprs_.insert(op);
+    return ExprMutator::VisitExpr_(op);
+  }
+
+  Expr VisitExpr_(const CallNode* op) final {
+    static const Op& call_tir_op_ = Op::Get("relax.call_tir");
+    Call call = Downcast<Call>(ExprMutator::VisitExpr_(op));
+    if (call->op != call_tir_op_) {
+      return call;
+    }
+    GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+    Array<Expr> call_tir_args = Downcast<Tuple>(call->args[1])->fields;
+    // Compute the layout free buffers
+    Array<Integer> layout_free_buffers;
+    for (size_t i = 0; i < call_tir_args.size(); i++) {
+      if (layout_free_exprs_.count(call_tir_args[i].get())) {
+        layout_free_buffers.push_back(Integer(i));
+      }
+    }
+    // Attach the layout free buffers to the tir::PrimFunc
+    tir::PrimFunc func = WithAttr(Downcast<tir::PrimFunc>(mod_->Lookup(gv)), "layout_free_buffers",
+                                  layout_free_buffers);
+    // Renew defs
+    func = tir::RenewDefs(func);
+    // Add the updated tir::PrimFunc in the IRModule
+    // Note the blockbuilder would automatically combine the same tir function
+    // So we don't need to worry about the duplicate insertion
+    GlobalVar new_gv = builder_->AddFunction(func, gv->name_hint);
+    // Create a new call node with the updated tir::PrimFunc
+    auto n = make_object<CallNode>(*op);
+    n->args = {new_gv, Tuple(call_tir_args)};
+    return Call(n);
+  }
+
+ private:
+  IRModule mod_;
+  std::unordered_set<const ExprNode*> layout_free_exprs_;
+};
+namespace transform {
+
+Pass AttachAttrLayoutFreeBuffers() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule mod, PassContext pc) { return AttrAttacher::Transform(mod); };
+  auto pass = CreateModulePass(pass_func, 0, "_AttachAttrLayoutFreeBuffers", {});
+  // Apply DeadCodeElimination to remove unused tir::PrimFunc
+  return tvm::transform::Sequential({pass, DeadCodeElimination()}, "AttachAttrLayoutFreeBuffers");
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.AttachAttrLayoutFreeBuffers")
+    .set_body_typed(AttachAttrLayoutFreeBuffers);
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
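A quick way to observe the effect of this pass (a sketch; `mod` is assumed to be a Relax module whose entry function carries the `num_input` attribute, as in the tests added below):

.. code-block:: python

    import tvm
    from tvm import relax

    after = relax.transform.AttachAttrLayoutFreeBuffers()(mod)
    for gvar, func in after.functions.items():
        if not isinstance(func, tvm.tir.PrimFunc):
            continue
        if func.attrs is not None and "layout_free_buffers" in func.attrs:
            # Indices of the PrimFunc params that hold weights or constants
            print(gvar.name_hint, [int(i) for i in func.attrs["layout_free_buffers"]])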
diff --git a/src/relax/transform/split_layout_rewrite_preproc.cc b/src/relax/transform/split_layout_rewrite_preproc.cc
new file mode 100644
index 000000000000..5fee946c26dd
--- /dev/null
+++ b/src/relax/transform/split_layout_rewrite_preproc.cc
@@ -0,0 +1,327 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/split_layout_rewrite_preproc.cc
+ * \brief Split the layout rewrite preproc blocks in TIR PrimFuncs into separate
+ *        functions after the meta_schedule RewriteLayout postproc.
+ */
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <tuple>
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+class SplitPrimFuncLayoutRewrite : public StmtMutator {
+ public:
+  explicit SplitPrimFuncLayoutRewrite(const PrimFunc& func) : original_func_(func) {}
+  std::tuple<Optional<PrimFunc>, PrimFunc> Transform(const PrimFunc& func) {
+    ICHECK(func->body.as<BlockRealizeNode>()) << "The body of the primfunc should be a root block.";
+    const auto& block = func->body.as<BlockRealizeNode>()->block;
+    visit_root_block(block.get());
+    if (layout_rewrite_preproc_stmts_.size() > 0) {
+      return std::make_tuple(create_layout_rewrite_preproc_func(), create_compute_func());
+    } else {
+      return std::make_tuple(NullOpt, func);
+    }
+  }
+
+ private:
+  void sort_rewrite_infos() {
+    std::sort(
+        rewrite_infos_.begin(), rewrite_infos_.end(),
+        [](const RewriteInfo& a, const RewriteInfo& b) { return a.buffer_index < b.buffer_index; });
+  }
+
+  PrimFunc create_layout_rewrite_preproc_func() const {
+    // Step 1: Check the number of pre_rewrite_buffers and post_rewrite_buffers
+    ICHECK(rewrite_infos_.size() > 0) << "There should be at least one buffer rewrite.";
+
+    // Step 2: Create the params for the new PrimFunc
+    Array<Var> params;
+    Map<Var, Buffer> buffer_map;
+
+    for (const auto& info : rewrite_infos_) {
+      params.push_back(Var(info.pre_rewrite_buffer->name, DataType::Handle()));
+      buffer_map.Set(params.back(), info.pre_rewrite_buffer);
+    }
+    for (const auto& info : rewrite_infos_) {
+      params.push_back(Var(info.post_rewrite_buffer->name, DataType::Handle()));
+      buffer_map.Set(params.back(), info.post_rewrite_buffer);
+    }
+
+    // Step 3: Create the body for the new PrimFunc
+    ICHECK(layout_rewrite_preproc_stmts_.size() > 0)
+        << "There should be at least one layout rewrite preproc stmt.";
+    Stmt body = layout_rewrite_preproc_stmts_.size() == 1 ? layout_rewrite_preproc_stmts_[0]
+                                                          : SeqStmt(layout_rewrite_preproc_stmts_);
+    body = BlockRealize(
+        /*iter_values=*/Array<PrimExpr>(),
+        /*predicate=*/const_true(),
+        /*block=*/
+        Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
+              /*name_hint=*/"root", body));
+
+    PrimFunc func = PrimFunc(params, body, VoidType(), buffer_map);
+
+    return RenewDefs(func);
+  }
+
+  PrimFunc create_compute_func() const {
+    // Step 1: Create the params for the new PrimFunc
+    Array<Var> params = original_func_->params;
+    Map<Var, Buffer> buffer_map = original_func_->buffer_map;
+    for (const auto& info : rewrite_infos_) {
+      const Var& param = params[info.buffer_index];
+      ICHECK(buffer_map[param] == info.pre_rewrite_buffer);
+      buffer_map.Set(param, info.post_rewrite_buffer);
+    }
+
+    // Step 2: Create the body for the new PrimFunc
+    Stmt body = compute_stmts_.size() == 1 ? compute_stmts_[0] : SeqStmt(compute_stmts_);
+    Block original_block = original_func_->body.as<BlockRealizeNode>()->block;
+    Array<Buffer> alloc_buffers;
+    for (const auto& buffer : original_block->alloc_buffers) {
+      auto it =
+          std::find_if(rewrite_infos_.begin(), rewrite_infos_.end(),
+                       [&](const RewriteInfo& info) { return info.post_rewrite_buffer == buffer; });
+      if (it == rewrite_infos_.end()) {
+        alloc_buffers.push_back(buffer);
+      }
+    }
+
+    body = BlockRealize(
+        /*iter_values=*/Array<PrimExpr>(),
+        /*predicate=*/const_true(),
+        /*block=*/
+        Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
+              /*name_hint=*/"root", body,
+              /*init=*/NullOpt,
+              /*alloc_buffers=*/alloc_buffers));
+
+    PrimFunc func = PrimFunc(original_func_->params, body, VoidType(), buffer_map);
+    return RenewDefs(func);
+  }
+
+  void visit_root_block(const BlockNode* op) {
+    Stmt body = op->body;
+    if (const auto* seq_stmt = body.as<SeqStmtNode>()) {
+      for (const auto& stmt : seq_stmt->seq) {
+        current_subtree_ = 0;
+        Stmt new_stmt = this->VisitStmt(stmt);
+        ICHECK(current_subtree_ != 0) << "There should be at least a block in the subtree.";
+        if (current_subtree_ == 1) {
+          layout_rewrite_preproc_stmts_.push_back(new_stmt);
+        } else {
+          compute_stmts_.push_back(new_stmt);
+        }
+      }
+    } else {
+      current_subtree_ = 0;
+      this->VisitStmt(body);
+      ICHECK(current_subtree_ == -1)
+          << "There should be a compute block if there is only one subtree under the root.";
+    }
+  }
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Block block = Downcast<Block>(StmtMutator::VisitStmt_(op));
+    auto it = op->annotations.find(attr::meta_schedule_layout_rewrite_preproc);
+    bool is_layout_rewrite_preproc =
+        it != op->annotations.end() && is_one(Downcast<PrimExpr>((*it).second));
+
+    if (current_subtree_ == 0) {
+      current_subtree_ = is_layout_rewrite_preproc ? 1 : -1;
+    } else if (current_subtree_ == 1) {
+      CHECK(is_layout_rewrite_preproc)
+          << "There is a layout rewrite block in the subtree, but meet a non-layout rewrite block.";
+    } else {
+      CHECK(!is_layout_rewrite_preproc)
+          << "There is a non-layout rewrite block in the subtree, but meet a layout rewrite block.";
+    }
+
+    if (is_layout_rewrite_preproc) {
+      ICHECK(op->reads.size() == 1) << "There should be only one read buffer in the layout rewrite";
+      ICHECK(op->writes.size() == 1)
+          << "There should be only one write buffer in the layout rewrite";
+      ICHECK(op->alloc_buffers.empty()) << "There should be no alloc buffer in the layout rewrite";
+      ICHECK(op->match_buffers.empty()) << "There should be no match buffer in the layout rewrite";
+      const Buffer& preproc_buffer = op->reads[0]->buffer;
+      int buffer_index = -1;
+      for (size_t i = 0; i < original_func_->params.size(); ++i) {
+        const Buffer& buffer = original_func_->buffer_map[original_func_->params[i]];
+        if (buffer == preproc_buffer) {
+          buffer_index = i;
+          break;
+        }
+      }
+      ICHECK(buffer_index != -1) << "The preproc buffer is not found in the original primfunc.";
+      rewrite_infos_.push_back(
+          RewriteInfo{buffer_index, op->reads[0]->buffer, op->writes[0]->buffer});
+
+      auto new_annotations = op->annotations;
+      new_annotations.erase(attr::meta_schedule_layout_rewrite_preproc);
+      auto n = make_object<BlockNode>(*block.get());
+      n->annotations = new_annotations;
+      return Block(n);
+    }
+    return block;
+  }
+
+ public:
+  struct RewriteInfo {
+    int buffer_index;
+    Buffer pre_rewrite_buffer;
+    Buffer post_rewrite_buffer;
+  };
+  std::vector<RewriteInfo> rewrite_infos_;
+
+ private:
+  /*! \brief The stmts that are used for layout rewrite preproc */
+  Array<Stmt> layout_rewrite_preproc_stmts_;
+  /*! \brief The stmts that are other than layout rewrite preproc */
+  Array<Stmt> compute_stmts_;
+  /*!
+   * \brief Whether the current subtree is a layout rewrite preproc subtree.
+   *  -1: visited a non-layout rewrite preproc block
+   *   0: unsure, not visited any block
+   *   1: visited a layout rewrite preproc block
+   */
+  int current_subtree_;
+  /*! \brief The original primfunc */
+  PrimFunc original_func_;
+};
+}  // namespace tir
+
+namespace relax {
+class SplitLayoutRewritePreproc : public ExprMutator {
+ public:
+  static IRModule Transform(const IRModule& mod) {
+    SplitLayoutRewritePreproc mutator(mod);
+
+    // Step 1: Split the primfunc into preproc and compute
+    for (auto [gv, func] : mod->functions) {
+      if (func->IsInstance<tir::PrimFuncNode>()) {
+        tir::SplitPrimFuncLayoutRewrite tir_rewriter(Downcast<tir::PrimFunc>(func));
+        auto [preproc_func, compute_func] = tir_rewriter.Transform(Downcast<tir::PrimFunc>(func));
+        if (preproc_func.defined()) {
+          mutator.split_funcs_.emplace(gv.get(),
+                                       std::make_tuple(preproc_func.value(), compute_func));
+          mutator.rewrite_infos_.emplace(gv.get(), tir_rewriter.rewrite_infos_);
+        }
+      }
+    }
+
+    for (auto [gv, func] : mod->functions) {
+      if (func->IsInstance<FunctionNode>()) {
+        auto relax_func = Downcast<Function>(func);
+        mutator.builder_->UpdateFunction(gv, Downcast<BaseFunc>(mutator(relax_func)));
+      }
+    }
+    return mutator.builder_->GetContextIRModule();
+  }
+
+ private:
+  explicit SplitLayoutRewritePreproc(const IRModule& mod) : ExprMutator(mod) {}
+  using ExprMutator::VisitExpr_;
+
+  Expr VisitExpr_(const CallNode* op) final {
+    static const Op& call_tir_op = Op::Get("relax.call_tir");
+    Call call = Downcast<Call>(ExprMutator::VisitExpr_(op));
+
+    // Step 1: Skip calls to anything other than `relax.call_tir`
+    if (!call->op.same_as(call_tir_op)) {
+      return call;
+    }
+
+    // Step 2: Skip if there is no preproc stage
+    const GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+    auto it = split_funcs_.find(gv.get());
+    if (it == split_funcs_.end()) {
+      return call;
+    }
+
+    // Step 3: Get the preproc and compute functions and update the module
+    const auto& [preproc_func, compute_func] = it->second;
+    GlobalVar preproc_gv = builder_->AddFunction(preproc_func, gv->name_hint + "_weight_prepack");
+    GlobalVar compute_gv = builder_->AddFunction(compute_func, gv->name_hint + "_prepacked");
+
+    // Step 4: Get rewrite infos
+    auto rewrite_infos_it = rewrite_infos_.find(gv.get());
+    ICHECK(rewrite_infos_it != rewrite_infos_.end())
+        << "Rewrite infos are not found for " << gv->name_hint;
+    const auto& rewrite_infos = rewrite_infos_it->second;
+
+    // Step 5: Emit the preproc call
+    Array<Expr> call_tir_args = Downcast<Tuple>(call->args[1])->fields;
+    Array<Expr> preproc_args;
+    Array<StructInfo> preproc_sinfo_list;
+    for (const auto& info : rewrite_infos) {
+      preproc_args.push_back(call_tir_args[info.buffer_index]);
+      tir::Buffer rewritten_buffer = info.post_rewrite_buffer;
+      for (const auto& shape_expr : rewritten_buffer->shape) {
+        CHECK(shape_expr.as<IntImmNode>()) << "Currently does not support rewrite buffer with "
+                                              "dynamic shape.";
+      }
+      preproc_sinfo_list.push_back(
+          TensorStructInfo(ShapeExpr(rewritten_buffer->shape), rewritten_buffer->dtype));
+    }
+    StructInfo preproc_sinfo = preproc_sinfo_list.size() > 1  //
+                                   ? TupleStructInfo(preproc_sinfo_list)  //
+                                   : preproc_sinfo_list[0];
+
+    // Step 6: Call the preproc function
+    Expr preproc_call =
+        builder_->Emit(Call(call_tir_op, {preproc_gv, Tuple(preproc_args)}, {}, {preproc_sinfo}));
+    if (rewrite_infos.size() == 1) {
+      call_tir_args.Set(rewrite_infos[0].buffer_index, preproc_call);
+    } else {
+      for (size_t i = 0; i < rewrite_infos.size(); ++i) {
+        call_tir_args.Set(rewrite_infos[i].buffer_index, TupleGetItem(preproc_call, i));
+      }
+    }
+    Expr main_call =
+        builder_->Emit(Call(call_tir_op, {compute_gv, Tuple(call_tir_args)}, {}, call->sinfo_args));
+
+    return main_call;
+  }
+
+ private:
+  std::unordered_map<const GlobalVarNode*, std::tuple<tir::PrimFunc, tir::PrimFunc>> split_funcs_;
+  std::unordered_map<const GlobalVarNode*,
+                     std::vector<tir::SplitPrimFuncLayoutRewrite::RewriteInfo>>
+      rewrite_infos_;
+};
+
+}  // namespace relax
+
+namespace transform {
+Pass SplitLayoutRewritePreproc() {
+  auto pass_func = [](IRModule mod, PassContext pc) {
+    return relax::SplitLayoutRewritePreproc::Transform(mod);
+  };
+  auto pass = CreateModulePass(pass_func, 0, "SplitLayoutRewritePreproc", {});
+  return tvm::transform::Sequential({pass, relax::transform::DeadCodeElimination()},
+                                    "SplitLayoutRewritePreproc");
+}
+TVM_REGISTER_GLOBAL("relax.transform.SplitLayoutRewritePreproc")
+    .set_body_typed(SplitLayoutRewritePreproc);
+}  // namespace transform
+}  // namespace tvm
diff --git a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py
index e2305de2afaf..8348c57c1949 100644
--- a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py
+++ b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py
@@ -61,7 +61,8 @@ def inner(mod):
         )
         sch = tvm.tir.Schedule(mod, debug_mask="all")
         sch.enter_postproc()
-        assert ctx.space_generator.postprocs[0].apply(sch)
+        if not ctx.space_generator.postprocs[0].apply(sch):
+            raise tvm.TVMError("RewriteLayout postproc failed")
         return sch.mod
 
     return inner
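Together with the `rewrite_layout.cc` change earlier in this patch, the failure contract is now: `RewriteLayout` reports a layout it cannot handle through `Apply`'s boolean return value instead of an escaping C++ exception, and the caller decides how to react; during tuning, MetaSchedule simply discards such a candidate. A sketch of the caller-side pattern (mirroring the test change above; `ctx` and `mod` are as constructed there):

.. code-block:: python

    sch = tvm.tir.Schedule(mod, debug_mask="all")
    sch.enter_postproc()
    if not ctx.space_generator.postprocs[0].apply(sch):
        # Surface the failure explicitly instead of crashing deep in C++.
        raise tvm.TVMError("RewriteLayout postproc failed")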
diff --git a/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py b/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py
new file mode 100644
index 000000000000..46f7c8aa87be
--- /dev/null
+++ b/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py
@@ -0,0 +1,311 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import tvm.testing
+
+from tvm import relax, tir
+from tvm.script import relax as R, tir as T, ir as I
+from tvm.relax.transform import CombineParallelMatmul
+from tvm.script.ir_builder import IRBuilder
+from tvm.script.ir_builder import relax as relax_builder
+
+
+def test_param():
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def matmul(
+            A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+        ):
+            for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+                with T.block("C"):
+                    with T.init():
+                        C[i, j] = T.float32(0)
+                    C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32"), y: R.Tensor((32, 32), "float32")):
+            R.func_attr({"num_input": 1})
+            cls = Before
+            with R.dataflow():
+                gv = R.call_tir(cls.matmul, (x, y), out_sinfo=R.Tensor((32, 32), "float32"))
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def matmul1(
+            A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+        ):
+            T.func_attr({"layout_free_buffers": [1]})
+            for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+                with T.block("C"):
+                    with T.init():
+                        C[i, j] = T.float32(0)
+                    C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32"), y: R.Tensor((32, 32), "float32")):
+            R.func_attr({"num_input": 1})
+            cls = Expected
+            with R.dataflow():
+                gv = R.call_tir(cls.matmul1, (x, y), out_sinfo=R.Tensor((32, 32), "float32"))
+                R.output(gv)
+            return gv
+
+    after = relax.transform.AttachAttrLayoutFreeBuffers()(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
+def test_const():
+    const_value = np.ones((32, 32), dtype="float32")
+
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def matmul(
+            A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+        ):
+            for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+                with T.block("C"):
+                    with T.init():
+                        C[i, j] = T.float32(0)
+                    C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32")):
+            R.func_attr({"num_input": 1})
+            cls = Before
+            with R.dataflow():
+                gv = R.call_tir(
+                    cls.matmul,
+                    (x, relax.const(const_value)),
+                    out_sinfo=R.Tensor((32, 32), "float32"),
+                )
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def matmul1(
+            A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+        ):
+            T.func_attr({"layout_free_buffers": [1]})
+            for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+                with T.block("C"):
+                    with T.init():
+                        C[i, j] = T.float32(0)
+                    C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32")):
+            R.func_attr({"num_input": 1})
+            cls = Expected
+            with R.dataflow():
+                gv = R.call_tir(
+                    cls.matmul1,
+                    (x, relax.const(const_value)),
+                    out_sinfo=R.Tensor((32, 32), "float32"),
+                )
+                R.output(gv)
+            return gv
+
+    after = relax.transform.AttachAttrLayoutFreeBuffers()(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
+def test_multiple_same_func():
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def matmul(
+            A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+        ):
+            for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+                with T.block("C"):
+                    with T.init():
+                        C[i, j] = T.float32(0)
+                    C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+        @R.function
+        def main(
+            x: R.Tensor((32, 32), "float32"),
+            w1: R.Tensor((32, 32), "float32"),
+            w2: R.Tensor((32, 32), "float32"),
+        ):
+            R.func_attr({"num_input": 1})
+            cls = Before
+            with R.dataflow():
+                lv1 = R.call_tir(
+                    cls.matmul,
+                    (x, w1),
+                    out_sinfo=R.Tensor((32, 32), "float32"),
+                )
+                gv = R.call_tir(
+                    cls.matmul,
+                    (lv1, w2),
+                    out_sinfo=R.Tensor((32, 32), "float32"),
+                )
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def matmul1(
+            A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+        ):
+            T.func_attr({"layout_free_buffers": [1]})
+            for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+                with T.block("C"):
+                    with T.init():
+                        C[i, j] = T.float32(0)
+                    C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+        @R.function
+        def main(
+            x: R.Tensor((32, 32), "float32"),
+            w1: R.Tensor((32, 32), "float32"),
+            w2: R.Tensor((32, 32), "float32"),
+        ):
+            R.func_attr({"num_input": 1})
+            cls = Expected
+            with R.dataflow():
+                lv1 = R.call_tir(
+                    cls.matmul1,
+                    (x, w1),
+                    out_sinfo=R.Tensor((32, 32), "float32"),
+                )
+                gv = R.call_tir(
+                    cls.matmul1,
+                    (lv1, w2),
+                    out_sinfo=R.Tensor((32, 32), "float32"),
+                )
+                R.output(gv)
+            return gv
+
+    after = relax.transform.AttachAttrLayoutFreeBuffers()(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
+def test_multiple_same_func_with_different_free_buffers():
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def matmul(
+            A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+        ):
+            for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+                with T.block("C"):
+                    with T.init():
+                        C[i, j] = T.float32(0)
+                    C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+        @R.function
+        def main(
+            x: R.Tensor((32, 32), "float32"),
+            w1: R.Tensor((32, 32), "float32"),
+            w2: R.Tensor((32, 32), "float32"),
+        ):
+            R.func_attr({"num_input": 1})
+            cls = Before
+            with R.dataflow():
+                lv1 = R.call_tir(
+                    cls.matmul,
+                    (x, w1),
+                    out_sinfo=R.Tensor((32, 32), "float32"),
+                )
+                gv = R.call_tir(
+                    cls.matmul,
+                    (w2, lv1),
+                    out_sinfo=R.Tensor((32, 32), "float32"),
+                )
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def matmul1(
+            A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+        ):
+            T.func_attr({"layout_free_buffers": [1]})
+            for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+                with T.block("C"):
+                    with T.init():
+                        C[i, j] = T.float32(0)
+                    C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+        @T.prim_func(private=True)
+        def matmul2(
+            A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+        ):
+            T.func_attr({"layout_free_buffers": [0]})
+            for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+                with T.block("C"):
+                    with T.init():
+                        C[i, j] = T.float32(0)
+                    C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+        @R.function
+        def main(
+            x: R.Tensor((32, 32), "float32"),
+            w1: R.Tensor((32, 32), "float32"),
+            w2: R.Tensor((32, 32), "float32"),
+        ):
+            R.func_attr({"num_input": 1})
+            cls = Expected
+            with R.dataflow():
+                lv1 = R.call_tir(
+                    cls.matmul1,
+                    (x, w1),
+                    out_sinfo=R.Tensor((32, 32), "float32"),
+                )
+                gv = R.call_tir(
+                    cls.matmul2,
+                    (w2, lv1),
+                    out_sinfo=R.Tensor((32, 32), "float32"),
+                )
+                R.output(gv)
+            return gv
+
+    after = relax.transform.AttachAttrLayoutFreeBuffers()(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_transform_split_layout_rewrite_preproc.py b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py
new file mode 100644
index 000000000000..e6b4c8ec4e2a
--- /dev/null
+++ b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py
@@ -0,0 +1,220 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm.testing
+from tvm import relax
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+
+def test_single_buffer():
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def tir_func(
+            X: T.Buffer((224, 224), "float32"),
+            W: T.Buffer((224, 224), "float32"),
+            Out: T.Buffer((224, 224), "float32"),
+        ):
+            T.func_attr({"layout_free_buffers": [1]})
+            W_rewrite = T.alloc_buffer((4, 4, 56, 56))
+            for i, j in T.grid(224, 224):
+                with T.block("W_rewrite"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    T.block_attr({"meta_schedule.layout_rewrite_preproc": T.bool(True)})
+                    W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj]
+            for i0, j0, i1, j1 in T.grid(4, 4, 56, 56):
+                with T.block("Out"):
+                    vi = T.axis.spatial(224, i0 * 56 + i1)
+                    vj = T.axis.spatial(224, j0 * 56 + j1)
+                    Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
+
+        @R.function
+        def forward(
+            x: R.Tensor((224, 224), dtype="float32"),
+            w: R.Tensor((224, 224), dtype="float32"),
+        ) -> R.Tensor((224, 224), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            cls = Before
+            with R.dataflow():
+                gv = R.call_tir(
+                    cls.tir_func, (x, w), out_sinfo=R.Tensor((224, 224), dtype="float32")
+                )
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class After:
+        @T.prim_func(private=True)
+        def tir_func_prepacked(
+            X: T.Buffer((224, 224), "float32"),
+            W_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
+            Out: T.Buffer((224, 224), "float32"),
+        ):
+            for i0, j0, i1, j1 in T.grid(4, 4, 56, 56):
+                with T.block("Out"):
+                    vi = T.axis.spatial(224, i0 * 56 + i1)
+                    vj = T.axis.spatial(224, j0 * 56 + j1)
+                    Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
+
+        @T.prim_func(private=True)
+        def tir_func_weight_prepack(
+            W: T.Buffer((224, 224), "float32"),
+            W_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
+        ):
+            for i, j in T.grid(224, 224):
+                with T.block("W_rewrite"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj]
+
+        @R.function
+        def forward(
+            x: R.Tensor((224, 224), dtype="float32"),
+            w: R.Tensor((224, 224), dtype="float32"),
+        ) -> R.Tensor((224, 224), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            cls = After
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.tir_func_weight_prepack, (w,), out_sinfo=R.Tensor((4, 4, 56, 56), "float32")
+                )
+                lv1 = R.call_tir(
+                    cls.tir_func_prepacked, (x, lv), out_sinfo=R.Tensor((224, 224), "float32")
+                )
+                gv: R.Tensor((224, 224), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    mod = relax.transform.SplitLayoutRewritePreproc()(Before)
+    tvm.ir.assert_structural_equal(mod, After)
+
+
+def test_multiple_buffers():
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def tir_func(
+            X: T.Buffer((224, 224), "float32"),
+            W1: T.Buffer((224, 224), "float32"),
+            W2: T.Buffer((224, 224), "float32"),
+            Out: T.Buffer((224, 224), "float32"),
+        ):
+            W1_rewrite = T.alloc_buffer((4, 4, 56, 56))
+            W2_rewrite = T.alloc_buffer((4, 4, 56, 56))
+            for i, j in T.grid(224, 224):
+                with T.block("W1_rewrite"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    T.block_attr({"meta_schedule.layout_rewrite_preproc": T.bool(True)})
+                    W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W1[vi, vj]
+            for i, j in T.grid(224, 224):
+                with T.block("W2_rewrite"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    T.block_attr({"meta_schedule.layout_rewrite_preproc": T.bool(True)})
+                    W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W2[vi, vj]
+            for i0, j0, i1, j1 in T.grid(4, 4, 56, 56):
+                with T.block("Out"):
+                    vi = T.axis.spatial(224, i0 * 56 + i1)
+                    vj = T.axis.spatial(224, j0 * 56 + j1)
+                    Out[vi, vj] = (
+                        X[vi, vj]
+                        + W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
+                        + W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
+                    )
+
+        @R.function
+        def forward(
+            x: R.Tensor((224, 224), dtype="float32"),
+            w1: R.Tensor((224, 224), dtype="float32"),
+            w2: R.Tensor((224, 224), dtype="float32"),
+        ) -> R.Tensor((224, 224), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            cls = Before
+            with R.dataflow():
+                gv = R.call_tir(
+                    cls.tir_func, (x, w1, w2), out_sinfo=R.Tensor((224, 224), dtype="float32")
+                )
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class After:
+        @T.prim_func(private=True)
+        def tir_func_prepacked(
+            X: T.Buffer((224, 224), "float32"),
+            W1_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
+            W2_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
+            Out: T.Buffer((224, 224), "float32"),
+        ):
+            for i0, j0, i1, j1 in T.grid(4, 4, 56, 56):
+                with T.block("Out"):
+                    vi = T.axis.spatial(224, i0 * 56 + i1)
+                    vj = T.axis.spatial(224, j0 * 56 + j1)
+                    Out[vi, vj] = (
+                        X[vi, vj]
+                        + W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
+                        + W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
+                    )
+
+        @T.prim_func(private=True)
+        def tir_func_weight_prepack(
+            W1: T.Buffer((224, 224), "float32"),
+            W2: T.Buffer((224, 224), "float32"),
+            W1_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
+            W2_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
+        ):
+            for i, j in T.grid(224, 224):
+                with T.block("W1_rewrite"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W1[vi, vj]
+            for i, j in T.grid(224, 224):
+                with T.block("W2_rewrite"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W2[vi, vj]
+
+        @R.function
+        def forward(
+            x: R.Tensor((224, 224), dtype="float32"),
+            w1: R.Tensor((224, 224), dtype="float32"),
+            w2: R.Tensor((224, 224), dtype="float32"),
+        ) -> R.Tensor((224, 224), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            cls = After
+            with R.dataflow():
+                lv0 = R.call_tir(
+                    cls.tir_func_weight_prepack,
+                    (w1, w2),
+                    out_sinfo=[
+                        R.Tensor((4, 4, 56, 56), "float32"),
+                        R.Tensor((4, 4, 56, 56), "float32"),
+                    ],
+                )
+                lv1 = R.call_tir(
+                    cls.tir_func_prepacked,
+                    (x, lv0[0], lv0[1]),
+                    out_sinfo=R.Tensor((224, 224), "float32"),
+                )
+                gv: R.Tensor((224, 224), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    mod = relax.transform.SplitLayoutRewritePreproc()(Before)
+    tvm.ir.assert_structural_equal(mod, After)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()