Skip to content

Commit

Permalink
Add a pass to legalize packed calls
Browse files Browse the repository at this point in the history
Change-Id: I8aa43d3a1b837b03a5cf3c6b32fc760bd78d3436
  • Loading branch information
Giuseppe Rossini committed Jun 21, 2021
1 parent 369745f commit d062e71
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 60 deletions.
45 changes: 24 additions & 21 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -269,18 +269,18 @@ class AOTExecutorCodegen : public ExprVisitor {
}

auto sid_value = sids_table_[sid];
if (!use_unpacked_api_) {
// Pack the sid inside the TVMValue
auto sid_array = te::Var(MakeString("sid_", sid, "_value"), DataType::Handle());
tvm::PrimExpr set_tensor =
tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
{sid_array, 0, tir::builtin::kArrData, sid_value});
stmts_.push_back(
tir::LetStmt(sid_array, StackAlloca("array", 1), tir::Evaluate(set_tensor)));
buffer_vars.push_back(sid_array);
} else {
buffer_vars.push_back(sid_value);
}
// if (!use_unpacked_api_) {
// // Pack the sid inside the TVMValue
// auto sid_array = te::Var(MakeString("sid_", sid, "_value"), DataType::Handle());
// tvm::PrimExpr set_tensor =
// tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
// {sid_array, 0, tir::builtin::kArrData, sid_value});
// stmts_.push_back(
// tir::LetStmt(sid_array, StackAlloca("array", 1), tir::Evaluate(set_tensor)));
// buffer_vars.push_back(sid_array);
// } else {
buffer_vars.push_back(sid_value);
// }
}
return buffer_vars;
}
Expand All @@ -300,15 +300,15 @@ class AOTExecutorCodegen : public ExprVisitor {
auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
{tir::StringImm(params_by_expr_[expr])});

if (!use_unpacked_api_) {
tvm::PrimExpr set_param_array =
tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
{param_array, 0, tir::builtin::kArrData, param_handle});
stmts_.push_back(
tir::LetStmt(param_array, StackAlloca("arg_value", 1), tir::Evaluate(set_param_array)));
} else {
stmts_.push_back(tir::LetStmt(param_array, param_handle, tir::Evaluate(0)));
}
// if (!use_unpacked_api_) {
// tvm::PrimExpr set_param_array =
// tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
// {param_array, 0, tir::builtin::kArrData, param_handle});
// stmts_.push_back(
// tir::LetStmt(param_array, StackAlloca("arg_value", 1), tir::Evaluate(set_param_array)));
// } else {
stmts_.push_back(tir::LetStmt(param_array, param_handle, tir::Evaluate(0)));
// }

return param_array;
}
Expand Down Expand Up @@ -787,6 +787,9 @@ class AOTExecutorCodegen : public ExprVisitor {
auto storage_rewrite = tir::transform::StorageRewrite();
mod_run = storage_rewrite(mod_run);

auto tir2runtime = tir::transform::MakePackedCalls();
mod_run = tir2runtime(mod_run);

// Update the lowered functions
auto target_host_str = target_host_->str();
if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) {
Expand Down
130 changes: 130 additions & 0 deletions src/tir/transforms/legalize_packed_calls.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file legalize_packed_calls.cc
* \brief Legalize packed calls in AOT by wrapping their arguments in TVMValues
*/
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <limits>
#include <string>
#include <unordered_map>
#include <vector>

namespace tvm {
namespace tir {

/*!
 * \brief Utility function to convert a concrete integer to a 32-bit integer PrimExpr.
 * \param num the number to convert; the check below fails if it exceeds the largest int
 * \return PrimExpr of type Int(32) representing num
 */
inline PrimExpr ConstInt32(size_t num) {
  // Compare in the unsigned domain so the range check does not mix a size_t
  // with a signed int operand.
  ICHECK_LE(num, static_cast<size_t>(std::numeric_limits<int>::max()));
  return tir::make_const(DataType::Int(32), static_cast<int>(num));
}


/*!
 * \brief Build a call to the tvm_stack_alloca built-in.
 * \param type the kind of object to allocate, forwarded as a string immediate
 * \param num the number of objects to allocate
 * \return Handle-typed PrimExpr holding the tvm_stack_alloca call
 */
PrimExpr StackAlloca(std::string type, size_t num) {
  const PrimExpr kind = tir::StringImm(type);
  const PrimExpr count = ConstInt32(num);
  return tir::Call(DataType::Handle(), tir::builtin::tvm_stack_alloca(), {kind, count});
}

// Map from a PrimExpr to a flag, hashed and compared with runtime::ObjectPtrHash /
// runtime::ObjectPtrEqual (presumably object identity rather than structural
// equality -- confirm against the runtime headers). In this file an entry set to
// true marks a PrimFunc parameter, which PackedCallLegalizer forwards unpacked.
using InputMap = std::unordered_map<PrimExpr, bool, runtime::ObjectPtrHash,
runtime::ObjectPtrEqual>;
/*!
 * \brief Legalization pass only used in AOT. Traverses the TIR statements and rewrites
 *  every tvm_call_cpacked so that its arguments are wrapped in TVMValues (through the
 *  tvm_struct_set built-in).
 */
class PackedCallLegalizer : public StmtExprMutator {
 public:
  /*!
   * \brief Legalize all the packed calls found in a statement.
   * \param params expressions (the PrimFunc inputs) that must be forwarded without packing
   * \param body the statement to rewrite
   * \return the rewritten statement
   */
  Stmt Legalize(const InputMap& params, tir::Stmt body) {
    inputs_ = params;
    return StmtExprMutator::VisitStmt(body);
  }

  /*!
   * \brief Rewrite an Evaluate node that wraps a tvm_call_cpacked.
   *
   * Given a packed call f(A, B, C), the result has the shape
   *   let tvm_value_a = stack_alloca("array", 1)
   *   let tvm_value_b = stack_alloca("array", 1)
   *   let tvm_value_c = stack_alloca("array", 1)
   *     struct_set(tvm_value_a, A)
   *     struct_set(tvm_value_b, B)
   *     struct_set(tvm_value_c, C)
   *     call_cpacked(f, tvm_value_a, tvm_value_b, tvm_value_c)
   *
   * Any other Evaluate node is handed to the default mutator.
   */
  Stmt VisitStmt_(const EvaluateNode* op) final {
    const CallNode* call = op->value.as<CallNode>();
    if (call == nullptr || !call->op.same_as(builtin::tvm_call_cpacked())) {
      return StmtExprMutator::VisitStmt_(op);
    }

    // The first argument is forwarded untouched; only the remaining ones are packed.
    Array<PrimExpr> packed_args{call->args[0]};
    std::vector<tir::Var> tvm_values;
    std::vector<Stmt> new_stmts;
    for (size_t i = 1; i < call->args.size(); ++i) {
      PrimExpr arg = call->args[i];
      // Use find() rather than operator[] so that a lookup never inserts into the map.
      auto it = inputs_.find(arg);
      if (it != inputs_.end() && it->second) {
        // No need to pack inputs of the prim_func
        packed_args.push_back(arg);
        continue;
      }
      // Pack the argument inside a TVMValue
      auto sid_array = tir::Var("tvm_value", DataType::Handle());
      tvm_values.push_back(sid_array);
      new_stmts.push_back(tir::Evaluate(
          tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
                         {sid_array, 0, tir::builtin::kArrData, arg})));
      packed_args.push_back(sid_array);
    }

    // Evaluate the packed call after all the struct_set statements.
    new_stmts.push_back(tir::Evaluate(tir::Call(call->dtype, call->op, packed_args)));
    Stmt legalized = tir::SeqStmt(new_stmts);

    // Bind every TVMValue around the whole sequence so that each variable is
    // still in scope when the packed call reads it.
    for (auto it = tvm_values.rbegin(); it != tvm_values.rend(); ++it) {
      legalized = LetStmt(*it, StackAlloca("array", 1), legalized);
    }
    return legalized;
  }

 private:
  InputMap inputs_;  // Store the inputs to the primfunc that don't need to be packed.
};

namespace transform {

/*!
 * \brief Create the pass that runs PackedCallLegalizer over the body of a PrimFunc.
 * \return The pass, created under the name "tir.LegalizePackedCalls".
 */
Pass LegalizePackedCalls() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    auto* func_node = f.CopyOnWrite();

    // Mark every parameter of the PrimFunc so the legalizer forwards it unpacked.
    InputMap prim_func_inputs;
    for (const auto& param : f->params) {
      prim_func_inputs[param] = true;
    }

    PackedCallLegalizer legalizer;
    func_node->body = legalizer.Legalize(prim_func_inputs, std::move(func_node->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.LegalizePackedCalls", {});
}
} // namespace transform

} // namespace tir
} // namespace tvm
42 changes: 3 additions & 39 deletions src/tir/transforms/storage_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,35 +138,6 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
if (op->op.same_as(builtin::address_of())) {
const LoadNode* l = op->args[0].as<LoadNode>();
this->VisitExpr(l->index);
} else if (op->op.same_as(builtin::tvm_call_cpacked())) {
// Recall that the arguments of a tvm_call_cpacked are passed as
// TVMValues. But a TVMValue is only a container, that points to
// a real buffer previously allocated. We need to signal that those
// buffers need to be live at the same time (i.e., cannot be overwritten during the function
// call)
Array<PrimExpr> args = op->args;
for (auto arg : args) {
const VarNode* var = arg.as<VarNode>();
if (value_to_alloc_.find(var) != value_to_alloc_.end()) {
auto allocs = value_to_alloc_[var];
for (const VarNode* alloc : allocs) {
VisitExpr_(alloc);
}
} else {
this->VisitExpr(arg);
}
}
} else if (op->op.same_as(builtin::tvm_struct_set())) {
// If we are using a struct_set built-in, and we are setting
// a DLTensor ArrayData field, let's note down the
// buffers that the TVMValue refers to
const VarNode* var = op->args[0].as<VarNode>();
const VarNode* alloc = op->args[3].as<VarNode>();
const int field_id = op->args[2].as<IntImmNode>()->value;
if (var && alloc && field_id == tir::builtin::kArrData) {
value_to_alloc_[var].push_back(alloc);
}
StmtExprVisitor::VisitExpr_(op);
} else {
StmtExprVisitor::VisitExpr_(op);
}
Expand Down Expand Up @@ -235,13 +206,6 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
bool in_thread_env_{false};
// The scope stack.
std::vector<StmtEntry> scope_;
// This is a map to connect TVMValues to real allocations. When we pass parameters
// to a tvm_call_cpacked, the data needs to be wrapped in a TVMValue. The wrapping
// happens through the tvm_struct_set built-in. This map is mapping the variable
// representing the TVMValue to the variable representing the real buffer. The live
// analysis needs to happen on the latter and not on the TVMValue which only acts as
// a container.
std::unordered_map<const VarNode*, std::vector<const VarNode*>> value_to_alloc_;
};

// Verify if the statement can be run safely via inplace fashion
Expand Down Expand Up @@ -923,11 +887,11 @@ class StoragePlanRewriter : public StmtExprMutator {
// symbolic free list, for non constant items.
std::list<StorageEntry*> sym_free_list_;
// The allocation attach map
std::unordered_map<const Object*, std::vector<StorageEntry*>> attach_map_;
std::unordered_map<const Object*, std::vector<StorageEntry*> > attach_map_;
// The allocation assign map
std::unordered_map<const VarNode*, StorageEntry*> alloc_map_;
// The allocations
std::vector<std::unique_ptr<StorageEntry>> alloc_vec_;
std::vector<std::unique_ptr<StorageEntry> > alloc_vec_;
// analyzer
arith::Analyzer analyzer_;
};
Expand Down Expand Up @@ -986,7 +950,7 @@ class VectorAllocRewriter : public StmtExprMutator {
}

// Internal access map
std::unordered_map<const VarNode*, std::vector<DataType>> acc_map_;
std::unordered_map<const VarNode*, std::vector<DataType> > acc_map_;
// Variables to remap
Map<tir::Var, PrimExpr> var_remap_;
// internal analyzer
Expand Down

0 comments on commit d062e71

Please sign in to comment.