From d062e717b5c85dfc1163075d01b4bebc195e5fd2 Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Fri, 18 Jun 2021 14:47:29 +0100 Subject: [PATCH] Add a pass to legalize packed calls Change-Id: I8aa43d3a1b837b03a5cf3c6b32fc760bd78d3436 --- src/relay/backend/aot_executor_codegen.cc | 45 +++---- src/tir/transforms/legalize_packed_calls.cc | 130 ++++++++++++++++++++ src/tir/transforms/storage_rewrite.cc | 42 +------ 3 files changed, 157 insertions(+), 60 deletions(-) create mode 100644 src/tir/transforms/legalize_packed_calls.cc diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 1dfa09ffcce9c..b8c9f50a53d0c 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -269,18 +269,18 @@ class AOTExecutorCodegen : public ExprVisitor { } auto sid_value = sids_table_[sid]; - if (!use_unpacked_api_) { - // Pack the sid inside the TVMValue - auto sid_array = te::Var(MakeString("sid_", sid, "_value"), DataType::Handle()); - tvm::PrimExpr set_tensor = - tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {sid_array, 0, tir::builtin::kArrData, sid_value}); - stmts_.push_back( - tir::LetStmt(sid_array, StackAlloca("array", 1), tir::Evaluate(set_tensor))); - buffer_vars.push_back(sid_array); - } else { - buffer_vars.push_back(sid_value); - } +// if (!use_unpacked_api_) { +// // Pack the sid inside the TVMValue +// auto sid_array = te::Var(MakeString("sid_", sid, "_value"), DataType::Handle()); +// tvm::PrimExpr set_tensor = +// tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), +// {sid_array, 0, tir::builtin::kArrData, sid_value}); +// stmts_.push_back( +// tir::LetStmt(sid_array, StackAlloca("array", 1), tir::Evaluate(set_tensor))); +// buffer_vars.push_back(sid_array); +// } else { + buffer_vars.push_back(sid_value); +// } } return buffer_vars; } @@ -300,15 +300,15 @@ class AOTExecutorCodegen : public ExprVisitor { auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), {tir::StringImm(params_by_expr_[expr])}); - if (!use_unpacked_api_) { - tvm::PrimExpr set_param_array = - tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {param_array, 0, tir::builtin::kArrData, param_handle}); - stmts_.push_back( - tir::LetStmt(param_array, StackAlloca("arg_value", 1), tir::Evaluate(set_param_array))); - } else { - stmts_.push_back(tir::LetStmt(param_array, param_handle, tir::Evaluate(0))); - } +// if (!use_unpacked_api_) { +// tvm::PrimExpr set_param_array = +// tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), +// {param_array, 0, tir::builtin::kArrData, param_handle}); +// stmts_.push_back( +// tir::LetStmt(param_array, StackAlloca("arg_value", 1), tir::Evaluate(set_param_array))); +// } else { + stmts_.push_back(tir::LetStmt(param_array, param_handle, tir::Evaluate(0))); +// } return param_array; } @@ -787,6 +787,9 @@ class AOTExecutorCodegen : public ExprVisitor { auto storage_rewrite = tir::transform::StorageRewrite(); mod_run = storage_rewrite(mod_run); + auto tir2runtime = tir::transform::MakePackedCalls(); + mod_run = tir2runtime(mod_run); + // Update the lowered functions auto target_host_str = target_host_->str(); if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) { diff --git a/src/tir/transforms/legalize_packed_calls.cc b/src/tir/transforms/legalize_packed_calls.cc new file mode 100644 index 0000000000000..fb7c7691f7852 --- /dev/null +++ b/src/tir/transforms/legalize_packed_calls.cc @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file make_packed_call.cc + * \brief Rewrite packed calls in AOT so that the arguments are packed + */ +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace tir { + +/*! + * \brief Utility function to convert a concrete integer to a PrimExpr. + * \param num the number to convert + * \return PrimExpr representing num + */ +inline PrimExpr ConstInt32(size_t num) { + ICHECK_LE(num, std::numeric_limits::max()); + return tir::make_const(DataType::Int(32), static_cast(num)); +} + + +/*! + * \brief Utility function to allocate a DLTensor or TVMValue + * \param type the type of allocation + * \param num the number of variable to allocate on the stack + * \return PrimExpr representing the allocated object + */ +PrimExpr StackAlloca(std::string type, size_t num) { + Array args = {tir::StringImm(type), ConstInt32(num)}; + return tir::Call(DataType::Handle(), tir::builtin::tvm_stack_alloca(), args); +} + +using InputMap = std::unordered_map; +/** + * This is a legalization pass only used in AOT. Traverse the TIR graph to legalize + * packed calls by making its argument wrapped in TVMValues (by using tvm_set_struct built-in) + */ +class PackedCallLegalizer : public StmtExprMutator { +public: + Stmt Legalize(const InputMap& params, tir::Stmt body){ + inputs_ = params; + return StmtExprMutator::VisitStmt(body); + } + + Stmt VisitStmt_(const EvaluateNode *op) final{ + if (tir::is_const_int(op->value)) return StmtExprMutator::VisitStmt_(op); + const CallNode* call = op->value.as(); + // Given a packed call f(A,B,C), we need a set of new statements + // let A_packed = set_struct(tvm_value1, A) + // let B_packed = set_struct(tvm_value2, B) + // let C_packed = set_struct(tvm_value3, C) + // call_packed(f, A_packed, B_packed, C_packed) + std::vector new_stmts; + if (call) { + if (call->op.same_as(builtin::tvm_call_cpacked())){ + Array packed_args{call->args[0]}; + for (unsigned i = 1; iargs.size(); i++){ + // No need to pack inputs of the prim_func + if (inputs_[call->args[i]] == true){ + packed_args.push_back(call->args[i]); + }else { + // Pack the argument inside a TVMValue + auto sid_array = tir::Var("tvm_value", DataType::Handle()); + tir::Stmt set_struct_stmt = tir::Evaluate( + tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), + {sid_array, 0, tir::builtin::kArrData, call->args[i]})); + new_stmts.push_back(LetStmt(sid_array, StackAlloca("array", 1), set_struct_stmt)); + packed_args.push_back(sid_array); + } + } + // Finally, evaluate the packed call and return a sequential statement + new_stmts.push_back(tir::Evaluate(tir::Call(call->dtype, call->op, packed_args))); + return tir::SeqStmt(new_stmts); + } + } + return StmtExprMutator::VisitStmt_(op); + } + + +private: + InputMap inputs_; // Store the inputs to the primfunc that don't need to be packed. +}; + +namespace transform { + +Pass LegalizePackedCalls(){ + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + + auto* n = f.CopyOnWrite(); + + // Create the + InputMap inputs; + for (auto i : f->params){ + inputs[i] = true; + } + n->body = PackedCallLegalizer().Legalize(inputs, std::move(n->body)); + return std::move(f); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LegalizePackedCalls", {}); +} +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index cd91a4b53317d..36eeddb17d89b 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -138,35 +138,6 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { if (op->op.same_as(builtin::address_of())) { const LoadNode* l = op->args[0].as(); this->VisitExpr(l->index); - } else if (op->op.same_as(builtin::tvm_call_cpacked())) { - // Recall that the arguments of a tvm_call_cpacked are passed as - // TVMValues. But a TVMValue is only a container, that points to - // a real buffer previously allocated. We need to signal that those - // buffers need to be live at the same time (i.e., cannot be overwritten during the function - // call) - Array args = op->args; - for (auto arg : args) { - const VarNode* var = arg.as(); - if (value_to_alloc_.find(var) != value_to_alloc_.end()) { - auto allocs = value_to_alloc_[var]; - for (const VarNode* alloc : allocs) { - VisitExpr_(alloc); - } - } else { - this->VisitExpr(arg); - } - } - } else if (op->op.same_as(builtin::tvm_struct_set())) { - // If we are using a struct_set built-in, and we are setting - // a DLTensor ArrayData field, let's note down the - // buffers that the TVMValue refers to - const VarNode* var = op->args[0].as(); - const VarNode* alloc = op->args[3].as(); - const int field_id = op->args[2].as()->value; - if (var && alloc && field_id == tir::builtin::kArrData) { - value_to_alloc_[var].push_back(alloc); - } - StmtExprVisitor::VisitExpr_(op); } else { StmtExprVisitor::VisitExpr_(op); } @@ -235,13 +206,6 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { bool in_thread_env_{false}; // The scope stack. std::vector scope_; - // This is a map to connect TVMValues to real allocations. When we pass parameters - // to a tvm_call_cpacked, the data needs to be wrapped in a TVMValue. The wrapping - // happens through the tvm_struct_set built-in. This map is mapping the variable - // representing the TVMValue to the variable representing the real buffer. The live - // analysis needs to happen on the latter and not on the TVMValue which only acts as - // a container. - std::unordered_map> value_to_alloc_; }; // Verify if the statement can be run safely via inplace fashion @@ -923,11 +887,11 @@ class StoragePlanRewriter : public StmtExprMutator { // symbolic free list, for non constant items. std::list sym_free_list_; // The allocation attach map - std::unordered_map> attach_map_; + std::unordered_map > attach_map_; // The allocation assign map std::unordered_map alloc_map_; // The allocations - std::vector> alloc_vec_; + std::vector > alloc_vec_; // analyzer arith::Analyzer analyzer_; }; @@ -986,7 +950,7 @@ class VectorAllocRewriter : public StmtExprMutator { } // Internal access map - std::unordered_map> acc_map_; + std::unordered_map > acc_map_; // Variables to remap Map var_remap_; // internal analyzer