Skip to content

Commit

Permalink
Predicated Load Optimization apache#2
Browse files Browse the repository at this point in the history
  • Loading branch information
spectrometerHBH authored and jinhongyii committed Jun 16, 2022
1 parent 6da285e commit c723660
Show file tree
Hide file tree
Showing 10 changed files with 992 additions and 11 deletions.
4 changes: 4 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ class AttrStmt : public Stmt {
TVM_DLL AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(AttrStmtNode);
};

/*!
Expand Down Expand Up @@ -1551,6 +1552,9 @@ constexpr const char* local_stage = "local_stage";
/*! \brief Mark vectorization length constraint on block */
constexpr const char* vector_bytes = "vector_bytes";

/*! \brief Mark the buffer as cache for buffer load address */
constexpr const char* cached_address = "cached_address";

/*!
* \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,8 @@ TVM_DLL Pass LowerVtcmAlloc();
*/
TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true, bool identify_equiv_terms = false);

TVM_DLL Pass OptimizePredicatedLoad(bool enable_predicated_load_optimizer = true);

/*!
* \brief Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and
* "vthread.x/y/z". Before the unification, two vars that are bound to a thread axis (e.g.,
Expand Down
5 changes: 1 addition & 4 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(
tir::transform::CommonSubexprElimTIR(!disable_cse_tir, enable_equiv_terms_in_cse_tir));

pass_list.push_back(tir::transform::OptimizePredicatedLoad(true));
return pass_list;
}

Expand Down Expand Up @@ -446,9 +447,7 @@ std::pair<IRModule, IRModule> SplitMixedModule(IRModule mod_mixed, const Target&
mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target));

IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed, target_host));

IRModule device_mod = ApplyPasses(mod_mixed, DeviceModulePassManager(mod_mixed, target));

auto keys = target->GetKeys();

CheckAndUpdateHostConsistency(&target, &target_host);
Expand All @@ -469,7 +468,6 @@ runtime::Module TIRToRuntime(const Map<Target, IRModule>& inputs_arg,
std::vector<runtime::Module> device_modules;
Map<Target, IRModule> inputs = inputs_arg;
Target target_host = target_host_arg;

// Fetch previous defined target host in targets
CheckAndUpdateHostConsistency(&inputs, &target_host);

Expand Down Expand Up @@ -503,7 +501,6 @@ runtime::Module TIRToRuntime(const Map<Target, IRModule>& inputs_arg,
auto pair = SplitMixedModule(ir_module, target, target_host);
auto& host_mod = pair.first;
auto& device_mod = pair.second;

ICHECK(host_mod.defined()) << "The split host module must be defined";

ICHECK(mhost_all.defined()) << "The host module must be defined";
Expand Down
53 changes: 49 additions & 4 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ void CodeGenC::PrintSSAAssign(const std::string& target, const std::string& src,
}

// Print a reference expression to a buffer.
std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index) {
std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index,
bool cached_address) {
const VarNode* buffer_var = buffer->data.get();
std::ostringstream os;
std::string vid = GetVarID(buffer_var);
Expand Down Expand Up @@ -187,6 +188,7 @@ std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExp

std::string buffer_str = vid;
if (!HandleTypeMatch(buffer_var, buffer_element_dtype) || is_vol) {
ICHECK(!cached_address);
std::stringstream temp;
temp << "(" << ptr_cast(buffer_element_dtype) << vid << ")";
buffer_str = temp.str();
Expand All @@ -201,14 +203,22 @@ std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExp
// int32. Therefore, we need to divide by the ratio of their
// sizes in that case.
int div_factor = (t.lanes() == 1) ? (32 / t.bits()) : t.lanes();

ICHECK(!cached_address);
os << "*("
<< "(" << ptr_cast(t) << vid << ")"
<< " + " << index_str << " / " << div_factor << ")";
} else if (t == buffer_element_dtype) {
os << buffer_str << "[" << index_str << "]";
if (!cached_address) {
os << buffer_str << "[" << index_str << "]";
} else {
os << "*" << index_str;
}
} else {
os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")";
if (!cached_address) {
os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")";
} else {
os << "*" << ptr_cast(t) << "(" << index_str << ")";
}
}

return os.str();
Expand Down Expand Up @@ -673,6 +683,29 @@ void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLI
Var buffer_var = op->buffer->data;
DataType element_dtype = op->buffer->dtype;

// addr[0]
if (cached_address_.count(op->buffer->data)) {
os << GetVarID(buffer_var.get());
return;
}
// data[addr[0]]
if (const BufferLoadNode* load = index.as<BufferLoadNode>()) {
if (cached_address_.count(load->buffer->data)) {
os << GetBufferRef(op->dtype, op->buffer.get(), load->buffer->data, true);
return;
}
}
// data[ramp(addr[0], 1, lanes)]
if (const RampNode* ramp = index.as<RampNode>()) {
if (const BufferLoadNode* load = ramp->base.as<BufferLoadNode>()) {
if (cached_address_.count(load->buffer->data)) {
ICHECK(is_one(ramp->stride));
os << GetBufferRef(op->dtype, op->buffer.get(), load->buffer->data, true);
return;
}
}
}

int lanes = op->dtype.lanes();
  // declare type.
if (value_dtype.lanes() == element_dtype.lanes()) {
Expand Down Expand Up @@ -736,6 +769,18 @@ void CodeGenC::VisitStmt_(const BufferStoreNode* op) {
PrimExpr index_expr = op->indices[0];
Var buffer_var = op->buffer->data;

if (cached_address_.count(op->buffer->data)) {
std::string value = this->PrintExpr(op->value);
this->PrintIndent();
if (!is_zero(index_expr)) {
stream << GetVarID(buffer_var.get()) << " = " << GetVarID(index_expr.as<VarNode>()) << " + "
<< value << ";\n";
} else {
stream << GetVarID(buffer_var.get()) << " = " << value << ";\n";
}
return;
}

if (value_dtype.lanes() == element_dtype.lanes()) {
std::string value = this->PrintExpr(op->value);
std::string ref = this->GetBufferRef(value_dtype, op->buffer.get(), index_expr);
Expand Down
5 changes: 4 additions & 1 deletion src/target/source/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
// Print reference to struct location
std::string GetStructRef(DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind);
// Print reference to a buffer as type t in index.
virtual std::string GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index);
virtual std::string GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index,
bool cached_address = false);

/*!
* \brief Handle volatile loads.
Expand Down Expand Up @@ -267,6 +268,8 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
std::unordered_map<const VarNode*, DataType> handle_data_type_;
/*! \brief Record of ops that have pre-defined global symbol. */
OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ = Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol");
// Buffers used for address calculation optimization
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> cached_address_;
// cache commonly used ops
const Op& builtin_call_extern_ = builtin::call_extern();
const Op& builtin_call_pure_extern_ = builtin::call_pure_extern();
Expand Down
16 changes: 16 additions & 0 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "codegen_cuda.h"

#include <tvm/arith/analyzer.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/stmt_functor.h>
Expand Down Expand Up @@ -925,7 +926,22 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
ICHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get());

auto it = op->annotations.find(tir::attr::cached_address);
if (it != op->annotations.end()) {
cached_address_.insert(op->buffer_var);
this->PrintIndent();
std::string scope = GetPtrStorageScope(op->buffer_var);
ICHECK(scope == "local");
DLDataType dtype =
runtime::String2DLDataType(std::string(Downcast<runtime::String>((*it).second)));
PrintType(DataType(dtype), stream);
stream << "* " << vid << ";\n";
this->PrintStmt(op->body);
return;
}

this->PrintIndent();

std::string scope = GetPtrStorageScope(op->buffer_var);
const VarNode* buffer = op->buffer_var.as<VarNode>();
if (scope.find("wmma.") == 0) {
Expand Down
4 changes: 4 additions & 0 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->Print(op->extents[i]);
}
p->stream << "], storage_scope = " << ptr_type->storage_scope;
if (!op->annotations.empty()) {
p->stream << "], annotations = ";
p->Print(op->annotations);
}
if (!is_one(op->condition)) {
p->stream << " if ";
p->Print(op->condition);
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/lower_warp_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
alloc_size = warp_group_ * factor;

return Allocate(op->buffer_var, op->dtype, {make_const(DataType::Int(32), alloc_size / width_)},
op->condition, this->VisitStmt(op->body));
op->condition, this->VisitStmt(op->body), op->annotations);
}

protected:
Expand Down
Loading

0 comments on commit c723660

Please sign in to comment.