From c7236604f823c78e01ed28af69aafba9801ee8c6 Mon Sep 17 00:00:00 2001 From: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Date: Mon, 6 Jun 2022 01:57:23 -0400 Subject: [PATCH] Predicated Load Optimization #2 --- include/tvm/tir/stmt.h | 4 + include/tvm/tir/transform.h | 2 + src/driver/driver_api.cc | 5 +- src/target/source/codegen_c.cc | 53 +- src/target/source/codegen_c.h | 5 +- src/target/source/codegen_cuda.cc | 16 + src/tir/ir/stmt.cc | 4 + src/tir/transforms/lower_warp_memory.cc | 2 +- .../transforms/optimize_predicated_load.cc | 910 ++++++++++++++++++ .../update_pointer_storage_scope.cc | 2 +- 10 files changed, 992 insertions(+), 11 deletions(-) create mode 100644 src/tir/transforms/optimize_predicated_load.cc diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 3c2ae0a93b86..19126a0f0f99 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -159,6 +159,7 @@ class AttrStmt : public Stmt { TVM_DLL AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(AttrStmtNode); }; /*! @@ -1551,6 +1552,9 @@ constexpr const char* local_stage = "local_stage"; /*! \brief Mark vectorization length constraint on block */ constexpr const char* vector_bytes = "vector_bytes"; +/*! \brief Mark the buffer as cache for buffer load address */ +constexpr const char* cached_address = "cached_address"; + /*! 
* \brief Check if attr_key is a pragma key extension * \param attr_key The attr key to be compared diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 64ef6a5eb157..5f862534be80 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -476,6 +476,8 @@ TVM_DLL Pass LowerVtcmAlloc(); */ TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true, bool identify_equiv_terms = false); +TVM_DLL Pass OptimizePredicatedLoad(bool enable_predicated_load_optimizer = true); + /*! * \brief Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and * "vthread.x/y/z". Before the unification, two vars that are bound to a thread axis (e.g., diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 8d27d5778374..30318ad58176 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -310,6 +310,7 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back( tir::transform::CommonSubexprElimTIR(!disable_cse_tir, enable_equiv_terms_in_cse_tir)); + pass_list.push_back(tir::transform::OptimizePredicatedLoad(true)); return pass_list; } @@ -446,9 +447,7 @@ std::pair SplitMixedModule(IRModule mod_mixed, const Target& mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target)); IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed, target_host)); - IRModule device_mod = ApplyPasses(mod_mixed, DeviceModulePassManager(mod_mixed, target)); - auto keys = target->GetKeys(); CheckAndUpdateHostConsistency(&target, &target_host); @@ -469,7 +468,6 @@ runtime::Module TIRToRuntime(const Map& inputs_arg, std::vector device_modules; Map inputs = inputs_arg; Target target_host = target_host_arg; - // Fetch previous defined target host in targets CheckAndUpdateHostConsistency(&inputs, &target_host); @@ -503,7 +501,6 @@ runtime::Module TIRToRuntime(const Map& inputs_arg, auto pair = SplitMixedModule(ir_module, target, target_host); auto& host_mod = pair.first; auto& 
device_mod = pair.second; - ICHECK(host_mod.defined()) << "The split host module must be defined"; ICHECK(mhost_all.defined()) << "The host module must be defined"; diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 3ad7882d792c..788f07f59386 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -159,7 +159,8 @@ void CodeGenC::PrintSSAAssign(const std::string& target, const std::string& src, } // Print a reference expression to a buffer. -std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index) { +std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index, + bool cached_address) { const VarNode* buffer_var = buffer->data.get(); std::ostringstream os; std::string vid = GetVarID(buffer_var); @@ -187,6 +188,7 @@ std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExp std::string buffer_str = vid; if (!HandleTypeMatch(buffer_var, buffer_element_dtype) || is_vol) { + ICHECK(!cached_address); std::stringstream temp; temp << "(" << ptr_cast(buffer_element_dtype) << vid << ")"; buffer_str = temp.str(); @@ -201,14 +203,22 @@ std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExp // int32. Therefore, we need to divide by the ratio of their // sizes in that case. int div_factor = (t.lanes() == 1) ? 
(32 / t.bits()) : t.lanes(); - + ICHECK(!cached_address); os << "*(" << "(" << ptr_cast(t) << vid << ")" << " + " << index_str << " / " << div_factor << ")"; } else if (t == buffer_element_dtype) { - os << buffer_str << "[" << index_str << "]"; + if (!cached_address) { + os << buffer_str << "[" << index_str << "]"; + } else { + os << "*" << index_str; + } } else { - os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")"; + if (!cached_address) { + os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")"; + } else { + os << "*" << ptr_cast(t) << "(" << index_str << ")"; + } } return os.str(); @@ -673,6 +683,29 @@ void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLI Var buffer_var = op->buffer->data; DataType element_dtype = op->buffer->dtype; + // addr[0] + if (cached_address_.count(op->buffer->data)) { + os << GetVarID(buffer_var.get()); + return; + } + // data[addr[0]] + if (const BufferLoadNode* load = index.as()) { + if (cached_address_.count(load->buffer->data)) { + os << GetBufferRef(op->dtype, op->buffer.get(), load->buffer->data, true); + return; + } + } + // data[ramp(addr[0], 1, lanes)] + if (const RampNode* ramp = index.as()) { + if (const BufferLoadNode* load = ramp->base.as()) { + if (cached_address_.count(load->buffer->data)) { + ICHECK(is_one(ramp->stride)); + os << GetBufferRef(op->dtype, op->buffer.get(), load->buffer->data, true); + return; + } + } + } + int lanes = op->dtype.lanes(); // delcare type. 
if (value_dtype.lanes() == element_dtype.lanes()) { @@ -736,6 +769,18 @@ void CodeGenC::VisitStmt_(const BufferStoreNode* op) { PrimExpr index_expr = op->indices[0]; Var buffer_var = op->buffer->data; + if (cached_address_.count(op->buffer->data)) { + std::string value = this->PrintExpr(op->value); + this->PrintIndent(); + if (!is_zero(index_expr)) { + stream << GetVarID(buffer_var.get()) << " = " << GetVarID(index_expr.as()) << " + " + << value << ";\n"; + } else { + stream << GetVarID(buffer_var.get()) << " = " << value << ";\n"; + } + return; + } + if (value_dtype.lanes() == element_dtype.lanes()) { std::string value = this->PrintExpr(op->value); std::string ref = this->GetBufferRef(value_dtype, op->buffer.get(), index_expr); diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 696ec62c5870..c35e69ea3633 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -203,7 +203,8 @@ class CodeGenC : public ExprFunctor, // Print reference to struct location std::string GetStructRef(DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind); // Print reference to a buffer as type t in index. - virtual std::string GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index); + virtual std::string GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index, + bool cached_address = false); /*! * \brief Handle volatile loads. @@ -267,6 +268,8 @@ class CodeGenC : public ExprFunctor, std::unordered_map handle_data_type_; /*! \brief Record of ops that have pre-defined global symbol. 
*/ OpAttrMap op_attr_global_symbol_ = Op::GetAttrMap("TGlobalSymbol"); + // Buffers used for address calculation optimization + std::unordered_set cached_address_; // cache commonly used ops const Op& builtin_call_extern_ = builtin::call_extern(); const Op& builtin_call_pure_extern_ = builtin::call_pure_extern(); diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 616e75f2e776..6d92ce376e22 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -24,6 +24,7 @@ #include "codegen_cuda.h" #include +#include #include #include #include @@ -925,7 +926,22 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { ICHECK(!is_zero(op->condition)); std::string vid = AllocVarID(op->buffer_var.get()); + auto it = op->annotations.find(tir::attr::cached_address); + if (it != op->annotations.end()) { + cached_address_.insert(op->buffer_var); + this->PrintIndent(); + std::string scope = GetPtrStorageScope(op->buffer_var); + ICHECK(scope == "local"); + DLDataType dtype = + runtime::String2DLDataType(std::string(Downcast((*it).second))); + PrintType(DataType(dtype), stream); + stream << "* " << vid << ";\n"; + this->PrintStmt(op->body); + return; + } + this->PrintIndent(); + std::string scope = GetPtrStorageScope(op->buffer_var); const VarNode* buffer = op->buffer_var.as(); if (scope.find("wmma.") == 0) { diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 43c2d3745964..08ac316eceaf 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -424,6 +424,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->Print(op->extents[i]); } p->stream << "], storage_scope = " << ptr_type->storage_scope; + if (!op->annotations.empty()) { + // the closing ']' of the extents list was already printed with storage_scope above + p->stream << ", annotations = "; + p->Print(op->annotations); + } if (!is_one(op->condition)) { p->stream << " if "; p->Print(op->condition); diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index d8250cd09888..bd5eec110ea0 100644 --- 
a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -253,7 +253,7 @@ class WarpAccessRewriter : protected StmtExprMutator { alloc_size = warp_group_ * factor; return Allocate(op->buffer_var, op->dtype, {make_const(DataType::Int(32), alloc_size / width_)}, - op->condition, this->VisitStmt(op->body)); + op->condition, this->VisitStmt(op->body), op->annotations); } protected: diff --git a/src/tir/transforms/optimize_predicated_load.cc b/src/tir/transforms/optimize_predicated_load.cc new file mode 100644 index 000000000000..51f607ea044a --- /dev/null +++ b/src/tir/transforms/optimize_predicated_load.cc @@ -0,0 +1,910 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../../arith/const_fold.h" +#include "../../arith/pattern_match.h" + +namespace tvm { +namespace tir { + +IntImm int32(int value) { return IntImm(DataType::Int(32), value); } + +struct Attach { + public: + enum class AttachType : int { + kAddition = 0, + kFloordiv = 1, + kFloormod = 2, + }; + + // type == kAddition: cur_var = dependent_var * c1 + c2 + // type == kFloordiv: cur_var = floordiv(dependent_var, c1) + // type == kFloormod: cur_var = floormod(dependent_var, c1) + // dependent var finally depends on attach_loop + Var attach_loop, cur_var, dependent_var; + AttachType type; + IntImm c1, c2; + + Attach() {} + Attach(Var attach_loop, Var cur_var, Var dependent_var, AttachType type, IntImm c1, IntImm c2) + : attach_loop(attach_loop), + cur_var(cur_var), + dependent_var(dependent_var), + type(type), + c1(c1), + c2(c2) {} + + PrimExpr Init(const PrimExpr& init) const { + if (type == AttachType::kAddition) { + return init * c1 + c2; + } else if (type == AttachType::kFloordiv) { + return floordiv(init, c1); + } + return floormod(init, c1); + } + + arith::IntSet Range(const arith::IntSet& range) const { + if (type == AttachType::kAddition) { + return arith::EvalSet(dependent_var * c1 + c2, {{dependent_var, range}}); + } else if (type == AttachType::kFloordiv) { + return arith::EvalSet(floordiv(dependent_var, c1), {{dependent_var, range}}); + } + return arith::EvalSet(floormod(dependent_var, c1), {{dependent_var, range}}); + } +}; + +/*! + * \brief Fuse multiple iterators by summing them with scaling. 
+ * result = sum_{i} (vars[i] * scale[i]) + base + */ +class SumFormNode : public PrimExprNode { + public: + Array vars; + Array scales; + IntImm base; + + // overrides + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("vars", &vars); + v->Visit("scales", &scales); + v->Visit("base", &base); + } + + bool SEqualReduce(const SumFormNode* other, SEqualReducer equal) const { + return equal(vars, other->vars) && equal(scales, other->scales) && equal(base, other->base); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(vars); + hash_reduce(scales); + hash_reduce(base); + } + + static constexpr const char* _type_key = "SumForm"; + TVM_DECLARE_FINAL_OBJECT_INFO(SumFormNode, PrimExprNode); +}; + +class SumForm : public PrimExpr { + public: + /*! + * \brief constructor. + * \param vars The vars to the sum. + * \param scale The scales to multiply. + * \param base The base + */ + SumForm(Array vars, Array scales, IntImm base) { + ICHECK_EQ(vars.size(), scales.size()); + auto n = make_object(); + n->dtype = base->dtype; + n->vars = std::move(vars); + n->scales = std::move(scales); + n->base = std::move(base); + data_ = std::move(n); + } + + TVM_DEFINE_OBJECT_REF_METHODS(SumForm, PrimExpr, SumFormNode); +}; + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "SumForm(" << op->vars << ", " << op->scales << ", " << op->base << ")"; + }); + +SumForm AddSumForm(const SumForm& a, const SumForm& b, bool neg = false) { + int coeff = neg ? -1 : 1; + std::unordered_set covered; + std::vector vars; + std::vector scales; + for (size_t i = 0; i < a->vars.size(); ++i) { + vars.push_back(a->vars[i]); + covered.insert(a->vars[i]); + + size_t j; + for (j = 0; j < b->vars.size(); ++j) { + if (a->vars[i].same_as(b->vars[j])) { + break; + } + } + scales.push_back(j == b->vars.size() + ? 
a->scales[i] + : int32(a->scales[i]->value + coeff * b->scales[j]->value)); + } + for (size_t i = 0; i < b->vars.size(); ++i) { + if (!covered.count(b->vars[i])) { + vars.push_back(b->vars[i]); + scales.push_back(int32(coeff * b->scales[i]->value)); + } + } + return SumForm(std::move(vars), std::move(scales), + int32(a->base->value + coeff * b->base->value)); +} + +SumForm MulSumForm(const SumForm& a, const IntImm& b) { + std::vector scales; + for (const IntImm scale : a->scales) { + scales.push_back(int32(scale->value * b->value)); + } + return SumForm(a->vars, scales, int32(a->base->value * b->value)); +} + +class LetVarBindingCanonicalizer : public ExprMutator { + public: + explicit LetVarBindingCanonicalizer( + std::unordered_map* var_range) + : var_range_(var_range) {} + + bool Canonicalize(const Var& top_var, const PrimExpr& binding) { + top_let_var_[binding] = top_var; + PrimExpr res = this->VisitExpr(binding); + if (fail) return false; + + const SumFormNode* ret = res.as(); + ICHECK(ret != nullptr); + ICHECK_EQ(ret->vars.size(), 1); + if (!is_one(ret->scales[0]) || !is_zero(ret->base)) { + let_var_buffer_map[top_var] = + decl_buffer({int32(1)}, DataType::Int(32), top_var->name_hint, "local"); + BuildAttachMap(top_var, ret->vars[0], Attach::AttachType::kAddition, ret->scales[0], + ret->base); + } + return true; + } + + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> inv_attach_map; + std::unordered_map let_var_buffer_map; + std::unordered_map attach_map; + + private: + void BuildAttachMap(const Var& cur_var, const Var& dependent_var, Attach::AttachType type, + IntImm c1, IntImm c2) { + // Append new attach + Attach attach; + auto it = attach_map.find(dependent_var); + if (it == attach_map.end()) { + attach = Attach(dependent_var, cur_var, dependent_var, type, c1, c2); + } else { + attach = Attach(it->second.attach_loop, cur_var, dependent_var, type, c1, c2); + } + inv_attach_map[dependent_var].push_back(attach); + attach_map[cur_var] = attach; + // 
calculate var range + (*var_range_)[cur_var] = attach.Range((*var_range_)[dependent_var]); + } + + Optional SearchExisitingAttach(const Var& dependent_var, Attach::AttachType type, + IntImm c1) { + auto it = inv_attach_map.find(dependent_var); + if (it != inv_attach_map.end()) { + for (const Attach& attach : it->second) { + if (attach.dependent_var.same_as(dependent_var) && attach.type == type && + attach.c1->value == c1->value) { + return attach.cur_var; + } + } + } + return NullOpt; + } + + PrimExpr VisitExprDefault_(const Object* op) final { + fail = true; + return SumForm({}, {}, int32(0)); + } + + PrimExpr VisitExpr_(const IntImmNode* op) final { return SumForm({}, {}, int32(op->value)); } + + PrimExpr VisitExpr_(const VarNode* op) final { + return SumForm({GetRef(op)}, {int32(1)}, int32(0)); + } + + PrimExpr VisitExpr_(const FloorDivNode* op) final { + PrimExpr ret_a = this->VisitExpr(op->a); + PrimExpr ret_b = this->VisitExpr(op->b); + const SumFormNode* a = ret_a.as(); + const SumFormNode* b = ret_b.as(); + if (a != nullptr && b != nullptr) { + ICHECK(a->vars.size() <= 1 && b->vars.size() <= 1); + // only (single var) / (constant) is supported; a constant dividend has no + // vars[0]/scales[0] to read, so it falls through to the failure path below + if (a->vars.size() == 1 && b->vars.size() == 0) { + // define let var for a + Var inner; + if (is_one(a->scales[0]) && is_zero(a->base)) { + // we don't have to introduce an intermediate var + inner = a->vars[0]; + } else { + // introduce an intermediate var + inner = a->vars[0].copy_with_suffix("_lin"); + let_var_buffer_map[inner] = + decl_buffer({int32(1)}, DataType::Int(32), inner->name_hint, "local"); + BuildAttachMap(inner, a->vars[0], Attach::AttachType::kAddition, a->scales[0], a->base); + } + // define let var for div, and a conjugate let var for mod + // first search for existing vars + Optional var_div = + SearchExisitingAttach(inner, Attach::AttachType::kFloordiv, b->base); + Optional var_mod = + SearchExisitingAttach(inner, Attach::AttachType::kFloormod, b->base); + // introduce new intermediate vars if they don't exist yet + if (!var_div.defined()) { + auto it = 
top_let_var_.find(GetRef(op)); + var_div = it == top_let_var_.end() + ? inner.copy_with_suffix("_div_" + std::to_string(b->base->value)) + : it->second; + let_var_buffer_map[var_div.value()] = + decl_buffer({int32(1)}, DataType::Int(32), var_div.value()->name_hint, "local"); + BuildAttachMap(var_div.value(), inner, Attach::AttachType::kFloordiv, b->base, int32(0)); + } + if (!var_mod.defined()) { + var_mod = inner.copy_with_suffix("_mod_" + std::to_string(b->base->value)); + let_var_buffer_map[var_mod.value()] = + decl_buffer({int32(1)}, DataType::Int(32), var_mod.value()->name_hint, "local"); + BuildAttachMap(var_mod.value(), inner, Attach::AttachType::kFloormod, b->base, int32(0)); + } + return SumForm({var_div.value()}, {int32(1)}, int32(0)); + } + } + fail = true; + return SumForm({}, {}, int32(0)); + } + + PrimExpr VisitExpr_(const FloorModNode* op) final { + PrimExpr ret_a = this->VisitExpr(op->a); + PrimExpr ret_b = this->VisitExpr(op->b); + const SumFormNode* a = ret_a.as(); + const SumFormNode* b = ret_b.as(); + if (a != nullptr && b != nullptr) { + ICHECK(a->vars.size() <= 1 && b->vars.size() <= 1); + // only (single var) % (constant) is supported; a constant dividend has no + // vars[0]/scales[0] to read, so it falls through to the failure path below + if (a->vars.size() == 1 && b->vars.size() == 0) { + // define let var for a + Var inner; + if (is_one(a->scales[0]) && is_zero(a->base)) { + // we don't have to introduce an intermediate var + inner = a->vars[0]; + } else { + // introduce an intermediate var + inner = a->vars[0].copy_with_suffix("_lin_"); + let_var_buffer_map[inner] = + decl_buffer({int32(1)}, DataType::Int(32), inner->name_hint, "local"); + BuildAttachMap(inner, a->vars[0], Attach::AttachType::kAddition, a->scales[0], a->base); + } + // define let var for mod + // first search for existing vars + Optional var_mod = + SearchExisitingAttach(inner, Attach::AttachType::kFloormod, b->base); + // introduce a new intermediate var if it doesn't exist yet + if (!var_mod.defined()) { + auto it = top_let_var_.find(GetRef(op)); + var_mod = it == top_let_var_.end() + ? 
inner.copy_with_suffix("_mod_" + std::to_string(b->base->value)) + : it->second; + let_var_buffer_map[var_mod.value()] = + decl_buffer({int32(1)}, DataType::Int(32), var_mod.value()->name_hint, "local"); + BuildAttachMap(var_mod.value(), inner, Attach::AttachType::kFloormod, b->base, int32(0)); + } + return SumForm({var_mod.value()}, {int32(1)}, int32(0)); + } + } + fail = true; + return SumForm({}, {}, int32(0)); + } + + PrimExpr VisitExpr_(const AddNode* op) final { + PrimExpr ret_a = this->VisitExpr(op->a); + PrimExpr ret_b = this->VisitExpr(op->b); + const SumFormNode* a = ret_a.as(); + const SumFormNode* b = ret_b.as(); + if (a != nullptr && b != nullptr) { + ICHECK(a->vars.size() <= 1 && b->vars.size() <= 1); + SumForm ret = AddSumForm(GetRef(a), GetRef(b)); + if (ret->vars.size() <= 1) { + return ret; + } + } + fail = true; + return SumForm({}, {}, int32(0)); + } + + PrimExpr VisitExpr_(const SubNode* op) final { + PrimExpr ret_a = this->VisitExpr(op->a); + PrimExpr ret_b = this->VisitExpr(op->b); + const SumFormNode* a = ret_a.as(); + const SumFormNode* b = ret_b.as(); + if (a != nullptr && b != nullptr) { + ICHECK(a->vars.size() <= 1 && b->vars.size() <= 1); + SumForm ret = AddSumForm(GetRef(a), GetRef(b), true); + if (ret->vars.size() <= 1) { + return ret; + } + } + fail = true; + return SumForm({}, {}, int32(0)); + } + + PrimExpr VisitExpr_(const MulNode* op) final { + PrimExpr ret_a = this->VisitExpr(op->a); + PrimExpr ret_b = this->VisitExpr(op->b); + const SumFormNode* a = ret_a.as(); + const SumFormNode* b = ret_b.as(); + if (a != nullptr && b != nullptr) { + ICHECK(a->vars.size() <= 1 && b->vars.size() <= 1); + if (a->vars.size() == 0) { + return MulSumForm(GetRef(b), a->base); + } else if (b->vars.size() == 0) { + return MulSumForm(GetRef(a), b->base); + } + } + fail = true; + return SumForm({}, {}, int32(0)); + } + + bool fail{false}; + std::unordered_map top_let_var_; + + std::unordered_map* var_range_; +}; + +class LoadAddressLinearizer : public 
ExprMutator { + public: + explicit LoadAddressLinearizer( + std::unordered_map* var_range) + : var_range_(var_range) {} + + bool Linearize(const PrimExpr& addr) { + result = Downcast(this->VisitExpr(addr)); + return !fail; + } + + SumForm result; + + private: + PrimExpr VisitExprDefault_(const Object* op) final { + fail = true; + return SumForm({}, {}, int32(0)); + } + + PrimExpr VisitExpr_(const IntImmNode* op) final { return SumForm({}, {}, int32(op->value)); } + + PrimExpr VisitExpr_(const VarNode* op) final { + return var_range_->count(GetRef(op)) ? SumForm({GetRef(op)}, {int32(1)}, int32(0)) + : SumForm({}, {}, int32(0)); + } + + PrimExpr VisitExpr_(const FloorDivNode* op) final { + PrimExpr ret_a = this->VisitExpr(op->a); + PrimExpr ret_b = this->VisitExpr(op->b); + const SumFormNode* a = ret_a.as(); + const SumFormNode* b = ret_b.as(); + if (a != nullptr && b != nullptr) { + if (a->vars.size() == 0 && b->vars.size() == 0) { + return SumForm({}, {}, int32(0)); + } + } + fail = true; + return SumForm({}, {}, int32(0)); + } + + PrimExpr VisitExpr_(const FloorModNode* op) final { + PrimExpr ret_a = this->VisitExpr(op->a); + PrimExpr ret_b = this->VisitExpr(op->b); + const SumFormNode* a = ret_a.as(); + const SumFormNode* b = ret_b.as(); + if (a != nullptr && b != nullptr) { + if (a->vars.size() == 0 && b->vars.size() == 0) { + return SumForm({}, {}, int32(0)); + } + } + fail = true; + return SumForm({}, {}, int32(0)); + } + + PrimExpr VisitExpr_(const AddNode* op) final { + PrimExpr ret_a = this->VisitExpr(op->a); + PrimExpr ret_b = this->VisitExpr(op->b); + const SumFormNode* a = ret_a.as(); + const SumFormNode* b = ret_b.as(); + if (a != nullptr && b != nullptr) { + return AddSumForm(GetRef(a), GetRef(b)); + } + fail = true; + return SumForm({}, {}, int32(0)); + } + + PrimExpr VisitExpr_(const SubNode* op) final { + PrimExpr ret_a = this->VisitExpr(op->a); + PrimExpr ret_b = this->VisitExpr(op->b); + const SumFormNode* a = ret_a.as(); + const SumFormNode* b = 
ret_b.as(); + if (a != nullptr && b != nullptr) { + // a - b: fold b in with negated scales and base (neg = true), + // otherwise the subtraction would be linearized as a + b + return AddSumForm(GetRef(a), GetRef(b), true); + } + fail = true; + return SumForm({}, {}, int32(0)); + } + + PrimExpr VisitExpr_(const MulNode* op) final { + PrimExpr ret_a = this->VisitExpr(op->a); + PrimExpr ret_b = this->VisitExpr(op->b); + const SumFormNode* a = ret_a.as(); + const SumFormNode* b = ret_b.as(); + if (a != nullptr && b != nullptr) { + if (a->vars.size() == 0 && !is_zero(a->base)) { + return MulSumForm(GetRef(b), a->base); + } else if (b->vars.size() == 0 && !is_zero(b->base)) { + return MulSumForm(GetRef(a), b->base); + } + } + fail = true; + return SumForm({}, {}, int32(0)); + } + + bool fail{false}; + std::unordered_map* var_range_; +}; + +class PredicatePrecompute : public StmtMutator { + public: + Stmt VisitStmt_(const LetStmtNode* op) final { + let_stmt_stack_.push_back(GetRef(op)); + Stmt result = StmtMutator::VisitStmt_(op); + let_stmt_stack_.pop_back(); + return result; + } + + Stmt VisitStmt_(const ForNode* op) final { + // enter + var_range_[op->loop_var] = arith::IntSet::FromMinExtent(op->min, op->extent); + For result = Downcast(StmtMutator::VisitStmt_(op)); + // Handle attach map + auto* new_for = result.CopyOnWrite(); + auto it = attach_map_.find(op->loop_var); + if (it != attach_map_.end()) { + // append index update stmts inside the loop + new_for->body = + AppendStmt(new_for->body, AttachUpdateStmt(op->loop_var, int32(1), it->second)); + } + // Handle addr update map + auto itt = addr_update_map_.find(op->loop_var); + if (itt != addr_update_map_.end()) { + std::vector outside{result}; + // append addr update stmts inside and outside the loop + for (const auto info : itt->second) { + new_for->body = AppendStmt( + new_for->body, + BufferStore(info.first, BufferLoad(info.first, {int32(0)}) + info.second, {int32(0)})); + outside.push_back(BufferStore( + info.first, BufferLoad(info.first, {int32(0)}) - info.second * op->extent, {int32(0)})); + } + return SeqStmt::Flatten(outside); + } + 
return result; + } + + Stmt VisitStmt_(const AttrStmtNode* attr) final { + pre_computed = false; + AttrStmt result = Downcast(StmtMutator::VisitStmt_(attr)); + auto* result_ptr = result.CopyOnWrite(); + // append the pre-computation of predicate buffers + if (!pre_computed) { + // append the initialization of addr buffer + std::unordered_map replace; + for (const auto it : let_var_buffer_map_) { + replace[it.first] = Load(it.first); + } + for (const auto it : var_range_) { + if (!let_var_buffer_map_.count(it.first)) { + replace[it.first] = it.second.min(); + } + } + for (const auto it : addr_map_) { + result_ptr->body = AppendStmt( + BufferStore(it.first, Substitute(it.second.first, replace), {it.second.second->data}), + result_ptr->body); + } + // append the initilziation of index buffers + for (const auto it : attach_map_) { + const std::vector& attaches = it.second; + std::unordered_map init_map; + for (const Attach& attach : attaches) { + PrimExpr value; + auto itt = init_map.find(attach.dependent_var); + if (itt == init_map.end()) { + // depend on loop var + ICHECK(var_range_.count(attach.dependent_var)); + value = attach.Init(var_range_[attach.dependent_var].min()); + } else { + value = attach.Init(itt->second); + } + init_map[attach.cur_var] = value; + result_ptr->body = AppendStmt(Store(attach.cur_var, value), result_ptr->body); + } + } + attach_map_.clear(); + for (const auto it : predicate_map_) { + result_ptr->body = AppendStmt(it.second, result_ptr->body); + } + pre_computed = true; + } + return result; + } + + Stmt VisitStmt_(const BufferStoreNode* store) final { return TransformPredicateLoad(store); } + + Stmt VisitStmt_(const AllocateNode* allocate) final { + Stmt result = StmtMutator::VisitStmt_(allocate); + // append the definition of predicate buffers, index buffers and addr buffer + for (const auto it : predicate_map_) { + result = Allocate(it.first->data, it.first->dtype, it.first->shape, Bool(true), result); + } + for (const auto it : 
let_var_buffer_map_) { + result = Allocate(it.second->data, it.second->dtype, it.second->shape, Bool(true), result); + } + for (const auto it : addr_map_) { + result = Allocate( + it.first->data, it.first->dtype, it.first->shape, Bool(true), result, + {{attr::cached_address, runtime::String(DLDataType2String(it.second.second->dtype))}}); + } + predicate_map_.clear(); + let_var_buffer_map_.clear(); + addr_map_.clear(); + return result; + } + + private: + PrimExpr Load(const Var& var) { + const auto it = let_var_buffer_map_.find(var); + if (it == let_var_buffer_map_.end()) { + return var; + } else { + return BufferLoad(it->second, {int32(0)}); + } + } + + Stmt Store(const Var& var, const PrimExpr& value) { + const auto it = let_var_buffer_map_.find(var); + ICHECK(it != let_var_buffer_map_.end()); + return BufferStore(it->second, value, {int32(0)}); + } + + private: + void AppendVarUpdate(std::vector* body, const Var& var, IntImm delta, + const std::vector& attaches) { + body->push_back(Store(var, Load(var) + delta)); + const auto it = addr_update_map_.find(var); + if (it != addr_update_map_.end()) { + for (const auto info : it->second) { + body->push_back(BufferStore( + info.first, BufferLoad(info.first, {int32(0)}) + delta * info.second, {int32(0)})); + } + } + body->push_back(AttachUpdateStmt(var, delta, attaches)); + } + + Stmt AttachUpdateStmt(Var dependent_var, IntImm inc, const std::vector& attaches) { + std::vector result; + ICHECK(inc->value != 0); + for (const Attach& attach : attaches) { + if (attach.dependent_var.same_as(dependent_var)) { + if (attach.type == Attach::AttachType::kAddition) { + IntImm delta = Downcast(inc * attach.c1); + result.push_back(Store(attach.cur_var, Load(attach.cur_var) + delta)); + result.push_back(AttachUpdateStmt(attach.cur_var, delta, attaches)); + } else if (attach.type == Attach::AttachType::kFloormod) { + // Search for conjugate div attach + size_t j; + for (j = 0; j < attaches.size(); ++j) { + if 
(attaches[j].dependent_var.same_as(attach.dependent_var) && + attaches[j].type == Attach::AttachType::kFloordiv && + attaches[j].c1->value == attach.c1->value) { + break; + } + } + // x <- x + C + // floormod(x + C, c1) <- floormod(floormod(x, c1) + floormod(C, c1), c1) + // 1) = floormod(x, c1) + floormod(C, c1) + // floordiv(x + C, c1) <- floordiv(x, c1) + floodiv(C, c1) + // 2) = floormod(x, c1) + floormod(C, c1) - c1 (if overflow) + // floordiv(x + C, c1) <- floordiv(x, c1) + floodiv(C, c1) + 1 + // 3) = floormod(x, c1) + floormod(C, c1) + c1 (if underflow) + // floordiv(x + C, c1) <- floordiv(x, c1) + floodiv(C, c1) - 1 + IntImm delta_mod = Downcast(floormod(inc, attach.c1)); + if (delta_mod->value != 0) { + AppendVarUpdate(&result, attach.cur_var, delta_mod, attaches); + } + if (j < attaches.size()) { + const Attach& attach_div = attaches[j]; + IntImm delta_div = Downcast(floordiv(inc, attach.c1)); + if (delta_div->value != 0) { + AppendVarUpdate(&result, attach_div.cur_var, delta_div, attaches); + } + } + // Construct the if body + std::vector if_body; + int sign = delta_mod->value > 0 ? -1 : 1; + IntImm delta_mod_if = int32(sign * attach.c1->value); + AppendVarUpdate(&if_body, attach.cur_var, delta_mod_if, attaches); + if (j < attaches.size()) { + const Attach& attach_div = attaches[j]; + IntImm delta_div_if = int32(-sign); + AppendVarUpdate(&if_body, attach_div.cur_var, delta_div_if, attaches); + } + result.push_back(IfThenElse(delta_mod->value > 0 + ? 
(greater_equal(Load(attach.cur_var), attach.c1)) + : less(Load(attach.cur_var), int32(0)), + SeqStmt::Flatten(if_body))); + } + // Floordiv will be updated together with FloorMod + } + } + return SeqStmt::Flatten(result); + } + + bool MatchLetVars(LetVarBindingCanonicalizer* canonicalizer) { + for (const LetStmt& let : let_stmt_stack_) { + if (!canonicalizer->Canonicalize(let->var, let->value)) { + return false; + } + } + return true; + } + + void SplitPredicate(PrimExpr predicate, std::vector* sub_predicates) { + arith::PVar sub_predicate, rest; + for (;;) { + if ((rest && sub_predicate).Match(predicate)) { + sub_predicates->push_back(sub_predicate.Eval()); + predicate = rest.Eval(); + } else { + sub_predicates->push_back(predicate); + return; + } + } + } + + BufferStore TransformPredicateLoad(const BufferStoreNode* store) { + // Canonicalize the let var bindings + LetVarBindingCanonicalizer canonicalizer(&var_range_); + LoadAddressLinearizer linearizer(&var_range_); + if (!MatchLetVars(&canonicalizer)) return GetRef(store); + local_predicate_map_.clear(); + // Check the pattern of load address and predicate + const CallNode* call = store->value.as(); + if (call != nullptr) { + const OpNode* op = call->op.as(); + if (op != nullptr && op->name == "tir.if_then_else") { + ICHECK_EQ(call->args.size(), 3); + const PrimExpr& predicate = call->args[0]; + const PrimExpr& lhs = call->args[1]; + const PrimExpr& rhs = call->args[2]; + // handle load address + PrimExpr addr; + bool lhs_fail{true}; + const BufferLoadNode* load = lhs.as(); + if (load != nullptr) { + addr = load->indices[0]; + if (const RampNode* ramp = load->indices[0].as()) { + addr = ramp->base; + } + if (linearizer.Linearize(addr)) { + lhs_fail = false; + } + } + // handle predicate + if (!lhs_fail) { + // split predicates into sub predicates + std::vector sub_predicates, new_sub_predicates; + SplitPredicate(predicate, &sub_predicates); + // Note down let var buffer map + 
let_var_buffer_map_.insert(canonicalizer.let_var_buffer_map.begin(), + canonicalizer.let_var_buffer_map.end()); + // parameterize sub-predicates + for (const PrimExpr& sub_predicate : sub_predicates) { + std::unordered_set covered; + auto collect = [&](const ObjectRef& obj) -> bool { + if (const tir::VarNode* var = obj.as()) { + if (!threads_.count(var->name_hint)) { + covered.insert(GetRef(var)); + } + } + return true; + }; + PreOrderVisit(sub_predicate, collect); + // validate parameterization + if (covered.size() == 0) { + // we don't have to pre-compute this sub-predicate + new_sub_predicates.push_back(sub_predicate); + } else if (covered.size() == 1) { + // p(var), max(var) <= 32 + const Var& var = *(covered.begin()); + const IntImmNode* min = var_range_[var].min().as(); + const IntImmNode* max = var_range_[var].max().as(); + if (min != nullptr && max != nullptr && max->value <= 32) { + // allocate buffer for this sub-predicate + const Buffer& buffer = + decl_buffer({int32(1)}, DataType::Int(32), "predicate", "local"); + // Create pre-compute loops for this sub-predicate + Var loop_var = var.copy_with_suffix("_pre"); + Stmt init = BufferStore(buffer, 0, {0}); + For compute = For( + loop_var, int32(min->value), int32(max->value - min->value + 1), + ForKind::kSerial, + BufferStore(buffer, + BufferLoad(buffer, {0}) | + ((Substitute(sub_predicate, {{var, loop_var}})) << loop_var), + {0})); + local_predicate_map_[buffer] = SeqStmt({init, compute}); + // rewrite this sub-prediate + new_sub_predicates.push_back((BufferLoad(buffer, {0}) >> Load(var)) & 1); + continue; + } + } else if (covered.size() == 2) { + // p(var1, var2), max(max(var1), max(var2)) <= 32, min(max(var1), max(var2)) <= 5 + auto it = covered.begin(); + Var var1 = (*it); + Var var2 = (*(++it)); + const IntImmNode* min1 = var_range_[var1].min().as(); + const IntImmNode* max1 = var_range_[var1].max().as(); + const IntImmNode* min2 = var_range_[var2].min().as(); + const IntImmNode* max2 = 
var_range_[var2].max().as(); + if (max1 != nullptr && max2 != nullptr) { + if (max1->value > max2->value) { + std::swap(var1, var2); + std::swap(min1, min2); + std::swap(max1, max2); + } + if (max1->value <= 5 && max2->value <= 64) { + // allocate buffer for this sub-predicate + const Buffer& buffer = decl_buffer({GetRef(max1) + 1}, DataType::Int(32), + "predicate", "local"); + local_predicate_map_[buffer] = Evaluate(0); + // Create pre-compute loops for this sub-predicate + Var loop_var1 = var1.copy_with_suffix("_pre"); + Var loop_var2 = var2.copy_with_suffix("_pre"); + For compute = For( + loop_var1, int32(min1->value), int32(max1->value - min1->value + 1), + ForKind::kSerial, + SeqStmt({BufferStore(buffer, 0, {loop_var1}), + For(loop_var2, int32(min2->value), + int32(max2->value - min2->value + 1), ForKind::kSerial, + BufferStore(buffer, + BufferLoad(buffer, {loop_var1}) | + ((Substitute(sub_predicate, {{var1, loop_var1}, + {var2, loop_var2}})) + << loop_var2), + {loop_var1}))})); + local_predicate_map_[buffer] = compute; + // rewrite this sub-prediate + new_sub_predicates.push_back((BufferLoad(buffer, {Load(var1)}) >> Load(var2)) & + 1); + continue; + } + } + } + // fail case + return GetRef(store); + } + // Note down attach map + for (const auto it : canonicalizer.attach_map) { + attach_map_[it.second.attach_loop].push_back(it.second); + } + // Note down predicate buffers + predicate_map_.insert(local_predicate_map_.begin(), local_predicate_map_.end()); + // Make new predicate + PrimExpr new_predicate = new_sub_predicates[0]; + for (size_t i = 1; i < new_sub_predicates.size(); ++i) { + new_predicate = new_sub_predicates[i] & new_predicate; + } + // introduce a var for this addr + const Buffer& buffer = decl_buffer({int32(1)}, DataType::Int(32), "addr", "local"); + addr_map_[buffer] = std::make_pair(addr, load->buffer); + // note down update info + for (size_t i = 0; i < linearizer.result->vars.size(); ++i) { + 
addr_update_map_[linearizer.result->vars[i]].push_back( + std::make_pair(buffer, linearizer.result->scales[i])); + } + PrimExpr new_lhs; + if (const RampNode* ramp = load->indices[0].as()) { + new_lhs = BufferLoad(load->buffer, + {Ramp(BufferLoad(buffer, {int32(0)}), ramp->stride, ramp->lanes)}); + } else { + new_lhs = BufferLoad(load->buffer, {BufferLoad(buffer, {int32(0)})}); + } + return BufferStore( + store->buffer, + if_then_else(cast(DataType::Bool(1), new_predicate), new_lhs, rhs, store->span), + store->indices); + } + } + } + return GetRef(store); + } + + Stmt AppendStmt(Stmt body, Stmt stmt) { + const SeqStmtNode* body_ptr = body.as(); + if (body_ptr == nullptr) { + return SeqStmt::Flatten(body, stmt); + } else { + return SeqStmt::Flatten(body_ptr->seq, stmt); + } + } + + arith::Analyzer analyzer_; + std::vector let_stmt_stack_; + bool pre_computed{false}; + std::unordered_map let_var_buffer_map_; + std::unordered_map var_range_; + + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> attach_map_; + + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> addr_map_; + std::unordered_map>, ObjectPtrHash, ObjectPtrEqual> + addr_update_map_; + std::unordered_map predicate_map_; + std::unordered_map local_predicate_map_; + + std::unordered_set threads_{"blockIdx.x", "blockIdx.y", "blockIdx.z", + "threadIdx.x", "threadIdx.y", "threadIdx.z"}; +}; + +namespace transform { + +Pass OptimizePredicatedLoad(bool enable_predicated_load_optimizer) { + auto pass_func = [enable_predicated_load_optimizer](PrimFunc f, IRModule m, PassContext ctx) { + if (enable_predicated_load_optimizer) { + auto* n = f.CopyOnWrite(); + n->body = PredicatePrecompute()(n->body); + } + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.OptimizePredicatedLoad", {}); +} + +// The pass can now be invoked via the pass infrastructure, but we also add a Python binding for it +TVM_REGISTER_GLOBAL("tir.transform.OptimizePredicatedLoad").set_body_typed(OptimizePredicatedLoad); + +} // namespace 
transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/update_pointer_storage_scope.cc b/src/tir/transforms/update_pointer_storage_scope.cc index 69db85eda2df..9424440e18af 100644 --- a/src/tir/transforms/update_pointer_storage_scope.cc +++ b/src/tir/transforms/update_pointer_storage_scope.cc @@ -61,7 +61,7 @@ PrimExpr UpdatePointerStorageScope::VisitExpr_(const VarNode* op) { Stmt UpdatePointerStorageScope::VisitStmt_(const AllocateNode* op) { auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); return Allocate(remapped, op->dtype, op->extents, StmtExprMutator::VisitExpr(op->condition), - StmtExprMutator::VisitStmt(op->body)); + StmtExprMutator::VisitStmt(op->body), op->annotations); } template