diff --git a/src/arith/compute_expr.h b/src/arith/compute_expr.h
index f842780bec7c..fddd34bcaab8 100644
--- a/src/arith/compute_expr.h
+++ b/src/arith/compute_expr.h
@@ -56,27 +56,6 @@ template<typename Op>
 inline PrimExpr ComputeReduce(
     const Array<PrimExpr>& values, PrimExpr empty_value);
 
-inline bool GetConst(PrimExpr e, int64_t* out) {
-  if (e.dtype().is_vector()) return false;
-  const int64_t* v = tir::as_const_int(e);
-  if (v) {
-    *out = *v; return true;
-  } else {
-    return false;
-  }
-}
-
-// get a small constant int
-inline bool GetConstInt(PrimExpr e, int* out) {
-  int64_t v1 = 0;
-  if (GetConst(e, &v1)) {
-    if (v1 > static_cast<int64_t>(
-        std::numeric_limits<int>::max())) return false;
-    *out = static_cast<int>(v1); return true;
-  }
-  return false;
-}
-
 template<>
 inline PrimExpr Compute<tir::AddNode>(PrimExpr a, PrimExpr b) {
   return a + b;
diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h
index e81b0881f927..0920ed3d0712 100644
--- a/src/arith/pattern_match.h
+++ b/src/arith/pattern_match.h
@@ -574,6 +574,17 @@ ramp(const Pattern<TBase>& base,
       base.derived(), stride.derived(), lanes.derived());
 }
 
+template<typename TA>
+inline PRampExpr<TA, PConstWithTypeLike<TA>, PConst<int>>
+ramp(const Pattern<TA>& base,
+     int stride,
+     int lanes) {
+  return PRampExpr<TA, PConstWithTypeLike<TA>, PConst<int>>(
+      base.derived(),
+      PConstWithTypeLike<TA>(base.derived(), stride),
+      PConst<int>(lanes));
+}
+
 /*!
  * \brief Pattern broadcast expression.
  * \tparam TA The pattern type of the value.
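Note: the new ramp overload above lets a caller fix the stride and lanes as plain
ints while keeping the base symbolic. A minimal usage sketch, mirroring the
codegen_c.cc call sites later in this patch (index and lanes here are stand-in
names, not part of the patch):

    arith::PVar<PrimExpr> base;
    if (arith::ramp(base, 1, lanes).Match(index)) {
      // base is bound only after a successful Match().
      PrimExpr b = base.Eval();
    }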
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 14302efe82fc..820a20c802fb 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -30,6 +30,7 @@
 #include "codegen_llvm.h"
 #include "codegen_cpu.h"
+#include "../../arith/pattern_match.h"
 #include "../build_common.h"
 
 namespace tvm {
 namespace codegen {
@@ -363,27 +364,27 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
         md_builder_->createTBAAStructTagNode(meta, meta, 0));
     return;
   }
-  int base = 0, width = 0;
+
+  int64_t base = 0, width = 0;
+  arith::PVar<IntImm> pbase, pstride;
+  arith::PVar<int> planes;
   // create meta-data for alias analysis
   // Use a group of binary tree ranges of memory banks.
   if (index.defined()) {
-    const RampNode* ramp = index.as<RampNode>();
-    if (ramp) {
-      int base, stride;
-      if (arith::GetConstInt(ramp->base, &base) &&
-          arith::GetConstInt(ramp->stride, &stride)) {
-        int xwith = ramp->lanes * stride;
-        width = 1;
-        while (width < xwith) {
-          width *= 2;
-        }
-        while (base % width) {
-          base -= base % width;
-          width *= 2;
-        }
+    if (arith::ramp(pbase, pstride, planes).Match(index)) {
+      base = pbase.Eval()->value;
+      int64_t xwith = planes.Eval() * pstride.Eval()->value;
+      width = 1;
+      while (width < xwith) {
+        width *= 2;
       }
-    } else {
-      if (arith::GetConstInt(index, &base)) width = 1;
+      while (base % width) {
+        base -= base % width;
+        width *= 2;
+      }
+    } else if (auto* ptr = index.as<IntImmNode>()) {
+      width = 1;
+      base = ptr->value;
     }
   }
   llvm::MDNode* meta = md_tbaa_root_;
@@ -394,8 +395,8 @@
     meta = md_builder_->createTBAAScalarTypeNode(buffer_type.str(), meta);
   // create a tree-shape access structure.
   if (width != 0) {
-    for (int w = 1024; w >= width; w /= 2) {
-      int b = (base / w) * w;
+    for (int64_t w = 1024; w >= width; w /= 2) {
+      int64_t b = (base / w) * w;
       std::stringstream os;
       os << buffer << ".w" << w << ".b" << b;
       meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta);
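Note: the rewritten AddAliasInfo keeps the original alias-range rounding, only
widened to int64_t. A self-contained trace of the arithmetic under an assumed
index of Ramp(6, 1, 3) (the concrete values are illustrative, not from the patch):

    #include <cstdint>
    #include <cstdio>

    int main() {
      int64_t base = 6, width = 1;
      int64_t xwith = 3 * 1;             // lanes * stride
      while (width < xwith) width *= 2;  // width: 1 -> 2 -> 4
      while (base % width) {             // widen until base is width-aligned
        base -= base % width;            // base:  6 -> 4 -> 0
        width *= 2;                      // width: 4 -> 8 -> 16
      }
      printf("base=%lld width=%lld\n",   // prints: base=0 width=16
             (long long)base, (long long)width);
      return 0;
    }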
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 6e7784c81f85..84604b8a0aed 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -23,8 +23,8 @@
 #include <iomanip>
 #include <cctype>
 #include "codegen_c.h"
+#include "../../arith/pattern_match.h"
 #include "../../arith/compute_expr.h"
-#include "../../tir/pass/ir_util.h"
 
 namespace tvm {
 namespace codegen {
@@ -198,8 +198,8 @@ std::string CodeGenC::GetBufferRef(
   // optimize for case where it is in register,
   if (HandleTypeMatch(buffer, t) && !is_vol) {
     // optimize for constant access
-    int offset;
-    if (arith::GetConstInt(index, &offset)) {
+    if (auto* ptr = index.as<IntImmNode>()) {
+      int64_t offset = ptr->value;
       CHECK_EQ(offset % t.lanes(), 0)
           << "Find unaligned vector load to a vector type";
       os << vid << '[' << (offset / t.lanes()) << ']';
@@ -663,9 +663,10 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) {  // NOLINT(*)
   } else {
     CHECK(is_one(op->predicate))
         << "predicated load is not supported";
-    PrimExpr base;
-    if (GetRamp1Base(op->index, op->dtype.lanes(), &base)) {
-      std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base);
+
+    arith::PVar<PrimExpr> base;
+    if (arith::ramp(base, 1, op->dtype.lanes()).Match(op->index)) {
+      std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base.Eval());
       HandleVolatileLoads(ref, op, os);
     } else {
       std::ostringstream svalue_expr;
@@ -708,10 +709,10 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
   } else {
     CHECK(is_one(op->predicate))
         << "Predicated store is not supported";
-    PrimExpr base;
-    if (GetRamp1Base(op->index, t.lanes(), &base)) {
+    arith::PVar<PrimExpr> base;
+    if (arith::ramp(base, 1, t.lanes()).Match(op->index)) {
       std::string value = this->PrintExpr(op->value);
-      this->PrintVecStore(op->buffer_var.get(), t, base, value);
+      this->PrintVecStore(op->buffer_var.get(), t, base.Eval(), value);
     } else {
       // The assignment below introduces side-effect, and the resulting value cannot
       // be reused across multiple expression, thus a new scope is needed
diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index 1d8004e9938f..7172298b2ced 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -103,11 +103,11 @@ spirv::Value CodeGenSPIRV::GetThreadIndex(
   spirv::Value v;
   if (ts.rank == 1) {
     v = builder_->GetLocalID(ts.dim_index);
-    int size = 0;
-    CHECK(arith::GetConstInt(extent, &size))
+    auto* sizeptr = extent.as<tir::IntImmNode>();
+    CHECK(sizeptr)
         << "SPIRV only allows constant thread group size "
         << " get " << extent;
     CHECK_LT(ts.dim_index, 3);
-    workgroup_size_[ts.dim_index] = static_cast<uint32_t>(size);
+    workgroup_size_[ts.dim_index] = static_cast<uint32_t>(sizeptr->value);
   } else {
     v = builder_->GetWorkgroupID(ts.dim_index);
   }
diff --git a/src/tir/pass/arg_binder.cc b/src/tir/pass/arg_binder.cc
index c684b9e68038..76c102b5664e 100644
--- a/src/tir/pass/arg_binder.cc
+++ b/src/tir/pass/arg_binder.cc
@@ -289,9 +289,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
   }
   // Byte_offset field.
   int data_bytes = GetVectorBytes(buffer->dtype);
-  int64_t const_offset;
-  if (arith::GetConst(buffer->elem_offset, &const_offset)) {
-    Bind_(make_const(DataType::UInt(64), const_offset * data_bytes),
+
+  if (const auto* const_offset = buffer->elem_offset.as<IntImmNode>()) {
+    Bind_(make_const(DataType::UInt(64), const_offset->value * data_bytes),
           TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset),
           arg_name + ".byte_offset", true);
   } else {
diff --git a/src/tir/pass/ir_util.h b/src/tir/pass/ir_util.h
index d8da61fdd961..a167433dd112 100644
--- a/src/tir/pass/ir_util.h
+++ b/src/tir/pass/ir_util.h
@@ -174,22 +174,6 @@ inline int GetTempAllocaAlignment(DataType type, int32_t const_size) {
   return align;
 }
 
-/*!
- * \brief Pattern match index to Ramp with stride=1
- *        This is a common pattern in continuous memory load.
- * \param index The index formula
- * \param lanes number of lanes in the ramp
- * \param base The result base.
- * \return true if pattern match success and store the base to base.
- */
-inline bool GetRamp1Base(PrimExpr index, int lanes, PrimExpr *base) {
-  const RampNode* r = index.as<RampNode>();
-  if (!r) return false;
-  if (!is_one(r->stride)) return false;
-  CHECK_EQ(r->lanes, lanes);
-  *base = r->base;
-  return true;
-}
 }  // namespace tir
 }  // namespace tvm
 #endif  // TVM_TIR_PASS_IR_UTIL_H_
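Note: every remaining call site in this patch follows the same replacement idiom.
Where the old helper was

    int value;
    CHECK(arith::GetConstInt(e, &value));

the code now casts to the constant node and reads its int64_t value directly,
narrowing explicitly only where a smaller type is required (e here is a stand-in
for whatever PrimExpr the caller inspects):

    if (const auto* imm = e.as<IntImmNode>()) {
      int64_t v = imm->value;  // no silent truncation to int
      // ... use v, or static_cast<int>(v) where a narrow type is needed ...
    }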
diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc
index c70962d8207e..24747a45600c 100644
--- a/src/tir/transforms/inject_virtual_thread.cc
+++ b/src/tir/transforms/inject_virtual_thread.cc
@@ -57,15 +57,15 @@ class ExprTouched final : public StmtExprVisitor {
   }
   void VisitExpr_(const CallNode *op) final {
     if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
-      int rw_mask = 0;
-      CHECK(arith::GetConstInt(op->args[4], &rw_mask));
+      const auto* rw_mask = op->args[4].as<IntImmNode>();
       const VarNode* buffer_var = op->args[1].as<VarNode>();
       CHECK(buffer_var);
+      CHECK(rw_mask);
       // read
-      if (rw_mask & 1) {
+      if (rw_mask->value & 1) {
         HandleUseVar(buffer_var);
       }
-      if (rw_mask & 2) {
+      if (rw_mask->value & 2) {
         HandleWriteVar(buffer_var);
       }
       this->VisitExpr(op->args[2]);
diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index 655a0074c7fd..abc0d9f5902c 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -163,8 +163,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       CHECK_GE(e.scope.dim_index, 0)
           << "vthread do not work with cross thread reduction";
       if (e.scope.rank == 1) {
-        CHECK(arith::GetConstInt(attr->value, &(e.extent)))
+        const auto* ptr = attr->value.as<IntImmNode>();
+        CHECK(ptr)
             << "Need constant extent for reduce set " << iv;
+        e.extent = static_cast<int>(ptr->value);
         if (reduce_set.count(iv->var.get())) {
           vred.push_back(e);
           ++nmatch;
diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc
index 71ba468a950f..76cfc434966d 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -30,7 +30,6 @@
 #include <unordered_set>
 
 #include "../pass/ir_util.h"
-#include "../../arith/compute_expr.h"
 
 namespace tvm {
 namespace tir {
@@ -94,11 +93,10 @@ class BuiltinLower : public StmtExprMutator {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
     op = stmt.as<AllocateNode>();
     // Get constant allocation bound.
-    int64_t dev_type;
     int64_t nbytes = GetVectorBytes(op->dtype);
     if (device_type_.defined()) {
-      if (arith::GetConst(device_type_, &dev_type)) {
-        if (dev_type == kDLCPU) {
+      if (const auto* dev_type = device_type_.as<IntImmNode>()) {
+        if (dev_type->value == kDLCPU) {
           int32_t constant_size = op->constant_allocation_size();
           if (constant_size > 0 &&
               constant_size * nbytes < runtime::kMaxStackAlloca) {
             return stmt;
diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc
index 0aee3c284422..33b11a5d42c9 100644
--- a/src/tir/transforms/lower_warp_memory.cc
+++ b/src/tir/transforms/lower_warp_memory.cc
@@ -37,7 +37,7 @@
 #include <unordered_set>
 
-#include "../pass/ir_util.h"
+#include "../../arith/pattern_match.h"
 #include "../../arith/compute_expr.h"
 #include "../../runtime/thread_storage_scope.h"
 
@@ -121,11 +121,11 @@ class WarpStoreCoeffFinder : private StmtVisitor {
       if (op->value.dtype().lanes() == 1) {
         UpdatePattern(op->index);
       } else {
-        PrimExpr base;
-        CHECK(GetRamp1Base(op->index, op->value.dtype().lanes(), &base))
+        arith::PVar<PrimExpr> base;
+        CHECK(arith::ramp(base, 1, op->value.dtype().lanes()).Match(op->index))
             << "LowerWarpMemory failed due to store index=" << op->index
             << ", can only handle continuous store";
-        UpdatePattern(base);
+        UpdatePattern(base.Eval());
       }
     } else {
       StmtVisitor::VisitStmt_(op);
@@ -137,19 +137,18 @@ class WarpStoreCoeffFinder : private StmtVisitor {
         arith::DetectLinearEquation(index, {warp_index_});
     CHECK_EQ(m.size(), 2U)
         << "LowerWarpMemory failed due to store index=" << index;
-    int coeff = 0;
     PrimExpr mcoeff = analyzer_->canonical_simplify(m[0]);
-
-    CHECK(arith::GetConstInt(mcoeff, &coeff) && coeff > 0)
+    const auto* mcoeff_as_int = mcoeff.as<IntImmNode>();
+    CHECK(mcoeff_as_int && mcoeff_as_int->value > 0)
        << "LowerWarpMemory failed due to store index=" << index
        << ", require positive constant coefficient on warp index " << warp_index_
        << " but get " << mcoeff;
 
     if (warp_coeff_ != 0) {
-      CHECK_EQ(warp_coeff_, coeff)
+      CHECK_EQ(warp_coeff_, mcoeff_as_int->value)
          << "LowerWarpMemory failed due to two different store coefficient to warp index";
     } else {
-      warp_coeff_ = coeff;
+      warp_coeff_ = mcoeff_as_int->value;
     }
   }
@@ -158,7 +157,7 @@ class WarpStoreCoeffFinder : private StmtVisitor {
   // the warp index
   Var warp_index_;
   // the coefficient
-  int warp_coeff_{0};
+  int64_t warp_coeff_{0};
   // analyzer.
   arith::Analyzer* analyzer_;
 };
@@ -184,10 +183,10 @@ class WarpIndexFinder : private StmtVisitor {
     if (op->attr_key == attr::thread_extent) {
       IterVar iv = Downcast<IterVar>(op->node);
       if (iv->thread_tag == "threadIdx.x") {
-        int value = 0;
-        CHECK(arith::GetConstInt(op->value, &value) &&
-              value <= warp_size_ &&
-              warp_size_ % value == 0)
+        auto* value_as_int = op->value.as<IntImmNode>();
+        CHECK(value_as_int &&
+              value_as_int->value <= warp_size_ &&
+              warp_size_ % value_as_int->value == 0)
            << "Expect threadIdx.x 's size to be no larger than, and a factor of"
            << " warp size(" << warp_size_ << ")" << " to enable warp memory"
            << " but get " << op->value << " instead";
@@ -198,7 +197,7 @@ class WarpIndexFinder : private StmtVisitor {
            << "Please create it using thread_axis once and reuse the axis "
            << "across multiple binds in the same kernel";
         } else {
-          width_ = value;
+          width_ = value_as_int->value;
           warp_index_ = iv;
         }
       }
@@ -281,9 +280,12 @@ class WarpAccessRewriter : protected StmtExprMutator {
   //   in this access pattern.
   std::pair<PrimExpr, PrimExpr> SplitIndexByGroup(const PrimExpr& index) {
     if (index.dtype().lanes() != 1) {
-      PrimExpr base, local_index, group;
-      CHECK(GetRamp1Base(index, index.dtype().lanes(), &base));
-      std::tie(local_index, group) = SplitIndexByGroup(base);
+      PrimExpr local_index, group;
+
+      arith::PVar<PrimExpr> base;
+      CHECK(arith::ramp(base, 1, index.dtype().lanes()).Match(index));
+
+      std::tie(local_index, group) = SplitIndexByGroup(base.Eval());
       local_index =
           RampNode::make(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes());
       return std::make_pair(local_index, group);
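Note on the inject_virtual_thread.cc hunk above: tvm_access_ptr carries its
access mode in args[4] as a bit mask (1 = read, 2 = write), so the rewrite only
needs the IntImmNode value. A condensed sketch of the decoding:

    const auto* rw_mask = op->args[4].as<IntImmNode>();
    CHECK(rw_mask);
    if (rw_mask->value & 1) { /* records a read of the buffer  */ }
    if (rw_mask->value & 2) { /* records a write of the buffer */ }

The same int64_t widening explains why warp_coeff_ in lower_warp_memory.cc
changes from int to int64_t: it now stores IntImmNode::value directly.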
diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc
index 99d437d9c24e..e5b2ad89ace9 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -326,13 +326,14 @@ class StorageFlattener : public StmtExprMutator {
           << "Prefetch dim should be the same as buffer dim";
 
       int block_size = 1,
-          elem_cnt = cache_line_size_ / e.buffer->dtype.bytes(),
-          shape = 0;
+          elem_cnt = cache_line_size_ / e.buffer->dtype.bytes();
 
       int starts = op->bounds.size() - 1;
-      while (starts > 0 && arith::GetConstInt(e.buffer->shape[starts], &shape)
-          && elem_cnt >= block_size * shape) {
-        block_size *= shape;
+
+      while (starts > 0) {
+        auto* shape_as_int = e.buffer->shape[starts].as<IntImmNode>();
+        if (shape_as_int == nullptr || block_size * shape_as_int->value > elem_cnt) break;
+        block_size *= static_cast<int>(shape_as_int->value);
         starts--;
       }
       PrimExpr stride(elem_cnt / block_size);
diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc
index 27c39d4c18aa..a8e3777904b4 100644
--- a/src/tir/transforms/unroll_loop.cc
+++ b/src/tir/transforms/unroll_loop.cc
@@ -50,16 +50,13 @@ class LoopUnroller : public StmtExprMutator {
 
   Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == "pragma_auto_unroll_max_step") {
-      int value = 0;
-      CHECK(arith::GetConstInt(op->value, &value));
+      int value = static_cast<int>(Downcast<Integer>(op->value)->value);
       std::swap(value, auto_max_step_);
       Stmt ret = this->VisitStmt(op->body);
       std::swap(value, auto_max_step_);
       return ret;
     } else if (op->attr_key == "pragma_unroll_explicit") {
-      int value = 0;
-      CHECK(arith::GetConstInt(op->value, &value));
-      bool explicit_unroll = value;
+      bool explicit_unroll = Downcast<Integer>(op->value)->value;
       std::swap(explicit_unroll, explicit_unroll_);
       Stmt ret = this->VisitStmt(op->body);
       std::swap(explicit_unroll, explicit_unroll_);
diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc
index cc4361dc3ad1..22995733b31e 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -519,12 +519,11 @@ class LoopVectorizer : public StmtMutator {
   Stmt VisitStmt_(const ForNode* op) final {
     if (op->for_type == ForType::Vectorized) {
       CHECK(is_zero(op->min));
-      int lanes = 0;
-      bool succ = arith::GetConstInt(op->extent, &lanes);
-      if (!succ || lanes < 1) {
+      auto* extent_as_int = op->extent.as<IntImmNode>();
+      if (!extent_as_int || extent_as_int->value < 1) {
        LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent;
       }
-      return Vectorizer(op->loop_var, lanes)(op->body);
+      return Vectorizer(op->loop_var, static_cast<int>(extent_as_int->value))(op->body);
     } else {
       return StmtMutator::VisitStmt_(op);
     }
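Note on the storage_flatten.cc hunk: the restructured loop computes how many
innermost dimensions of the prefetched buffer fit in one cache line. A
standalone trace under assumed values (a 64-byte line with float32 elements,
so elem_cnt = 16, and a constant shape [8, 4, 2]; the real code reads the
dimensions via e.buffer->shape[starts].as<IntImmNode>()):

    #include <cstdint>
    #include <cstdio>

    int main() {
      const int64_t shape[] = {8, 4, 2};
      int block_size = 1, elem_cnt = 16;
      int starts = 2;  // index of the innermost dimension
      while (starts > 0) {
        int64_t s = shape[starts];
        if (block_size * s > elem_cnt) break;  // next dim no longer fits
        block_size *= static_cast<int>(s);     // fold the dim into the block
        starts--;
      }
      // prints: block_size=8 starts=0 (dims 2 and 4 folded, 8 left out)
      printf("block_size=%d starts=%d\n", block_size, starts);
      return 0;
    }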