Skip to content

Commit

Permalink
[ARITH] Remove legacy const pattern functions (apache#5387)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and dpankratz committed Apr 24, 2020
1 parent 9af9168 commit e132a55
Show file tree
Hide file tree
Showing 14 changed files with 87 additions and 112 deletions.
21 changes: 0 additions & 21 deletions src/arith/compute_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,27 +56,6 @@ template<typename Op>
inline PrimExpr ComputeReduce(
const Array<PrimExpr>& values, PrimExpr empty_value);

inline bool GetConst(PrimExpr e, int64_t* out) {
if (e.dtype().is_vector()) return false;
const int64_t* v = tir::as_const_int(e);
if (v) {
*out = *v; return true;
} else {
return false;
}
}

// get a small constant int
inline bool GetConstInt(PrimExpr e, int* out) {
int64_t v1 = 0;
if (GetConst(e, &v1)) {
if (v1 > static_cast<int64_t>(
std::numeric_limits<int>::max())) return false;
*out = static_cast<int>(v1); return true;
}
return false;
}

template<>
inline PrimExpr Compute<tir::AddNode>(PrimExpr a, PrimExpr b) {
return a + b;
Expand Down
11 changes: 11 additions & 0 deletions src/arith/pattern_match.h
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,17 @@ ramp(const Pattern<TBase>& base,
base.derived(), stride.derived(), lanes.derived());
}

template<typename TBase>
inline PRampExpr<TBase, PConstWithTypeLike<TBase>, PConst<int>>
ramp(const Pattern<TBase>& base,
int stride,
int lanes) {
return PRampExpr<TBase, PConstWithTypeLike<TBase>, PConst<int>>(
base.derived(),
PConstWithTypeLike<TBase>(base.derived(), stride),
PConst<int>(lanes));
}

/*!
* \brief Pattern broadcast expression.
* \tparam TA The pattern type of the value.
Expand Down
39 changes: 20 additions & 19 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

#include "codegen_llvm.h"
#include "codegen_cpu.h"
#include "../../arith/pattern_match.h"
#include "../build_common.h"
namespace tvm {
namespace codegen {
Expand Down Expand Up @@ -363,27 +364,27 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
md_builder_->createTBAAStructTagNode(meta, meta, 0));
return;
}
int base = 0, width = 0;

int64_t base = 0, width = 0;
arith::PVar<IntImm> pbase, pstride;
arith::PVar<int> planes;
// create meta-data for alias analysis
// Use a group of binary tree ranges of memory banks.
if (index.defined()) {
const RampNode* ramp = index.as<RampNode>();
if (ramp) {
int base, stride;
if (arith::GetConstInt(ramp->base, &base) &&
arith::GetConstInt(ramp->stride, &stride)) {
int xwith = ramp->lanes * stride;
width = 1;
while (width < xwith) {
width *= 2;
}
while (base % width) {
base -= base % width;
width *= 2;
}
if (arith::ramp(pbase, pstride, planes).Match(index)) {
base = pbase.Eval()->value;
int64_t xwith = planes.Eval() * pstride.Eval()->value;
width = 1;
while (width < xwith) {
width *= 2;
}
} else {
if (arith::GetConstInt(index, &base)) width = 1;
while (base % width) {
base -= base % width;
width *= 2;
}
} else if (auto* ptr = index.as<tir::IntImmNode>()) {
width = 1;
base = ptr->value;
}
}
llvm::MDNode* meta = md_tbaa_root_;
Expand All @@ -394,8 +395,8 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
meta = md_builder_->createTBAAScalarTypeNode(buffer_type.str(), meta);
// create a tree-shape access structure.
if (width != 0) {
for (int w = 1024; w >= width; w /= 2) {
int b = (base / w) * w;
for (int64_t w = 1024; w >= width; w /= 2) {
int64_t b = (base / w) * w;
std::stringstream os;
os << buffer << ".w" << w << ".b" << b;
meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta);
Expand Down
19 changes: 10 additions & 9 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
#include <iomanip>
#include <cctype>
#include "codegen_c.h"
#include "../../arith/pattern_match.h"
#include "../../arith/compute_expr.h"
#include "../../tir/pass/ir_util.h"

namespace tvm {
namespace codegen {
Expand Down Expand Up @@ -198,8 +198,8 @@ std::string CodeGenC::GetBufferRef(
// optimize for case where it is in register,
if (HandleTypeMatch(buffer, t) && !is_vol) {
// optimize for constant access
int offset;
if (arith::GetConstInt(index, &offset)) {
if (auto* ptr = index.as<tir::IntImmNode>()) {
int64_t offset = ptr->value;
CHECK_EQ(offset % t.lanes(), 0)
<< "Find unaligned vector load to a vector type";
os << vid << '[' << (offset / t.lanes()) << ']';
Expand Down Expand Up @@ -663,9 +663,10 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*)
} else {
CHECK(is_one(op->predicate))
<< "predicated load is not supported";
PrimExpr base;
if (GetRamp1Base(op->index, op->dtype.lanes(), &base)) {
std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base);

arith::PVar<PrimExpr> base;
if (arith::ramp(base, 1, op->dtype.lanes()).Match(op->index)) {
std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base.Eval());
HandleVolatileLoads(ref, op, os);
} else {
std::ostringstream svalue_expr;
Expand Down Expand Up @@ -708,10 +709,10 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
} else {
CHECK(is_one(op->predicate))
<< "Predicated store is not supported";
PrimExpr base;
if (GetRamp1Base(op->index, t.lanes(), &base)) {
arith::PVar<PrimExpr> base;
if (arith::ramp(base, 1, t.lanes()).Match(op->index)) {
std::string value = this->PrintExpr(op->value);
this->PrintVecStore(op->buffer_var.get(), t, base, value);
this->PrintVecStore(op->buffer_var.get(), t, base.Eval(), value);
} else {
// The assignment below introduces side-effect, and the resulting value cannot
// be reused across multiple expression, thus a new scope is needed
Expand Down
6 changes: 3 additions & 3 deletions src/target/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,11 @@ spirv::Value CodeGenSPIRV::GetThreadIndex(
spirv::Value v;
if (ts.rank == 1) {
v = builder_->GetLocalID(ts.dim_index);
int size = 0;
CHECK(arith::GetConstInt(extent, &size))
auto* sizeptr = extent.as<tir::IntImmNode>();
CHECK(sizeptr)
<< "SPIRV only allows constant thread group size " << " get " << extent;
CHECK_LT(ts.dim_index, 3);
workgroup_size_[ts.dim_index] = static_cast<uint32_t>(size);
workgroup_size_[ts.dim_index] = static_cast<uint32_t>(sizeptr->value);
} else {
v = builder_->GetWorkgroupID(ts.dim_index);
}
Expand Down
6 changes: 3 additions & 3 deletions src/tir/pass/arg_binder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
}
// Byte_offset field.
int data_bytes = GetVectorBytes(buffer->dtype);
int64_t const_offset;
if (arith::GetConst(buffer->elem_offset, &const_offset)) {
Bind_(make_const(DataType::UInt(64), const_offset * data_bytes),

if (const auto* const_offset = buffer->elem_offset.as<IntImmNode>()) {
Bind_(make_const(DataType::UInt(64), const_offset->value * data_bytes),
TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset),
arg_name + ".byte_offset", true);
} else {
Expand Down
16 changes: 0 additions & 16 deletions src/tir/pass/ir_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,22 +174,6 @@ inline int GetTempAllocaAlignment(DataType type, int32_t const_size) {
return align;
}

/*!
* \brief Pattern match index to Ramp with stride=1
* This is a common pattern in continuous memory load.
* \param index The index formula
* \param lanes number of lanes in the ramp
* \param base The result base.
* \return true if pattern match success and store the base to base.
*/
inline bool GetRamp1Base(PrimExpr index, int lanes, PrimExpr *base) {
const RampNode* r = index.as<RampNode>();
if (!r) return false;
if (!is_one(r->stride)) return false;
CHECK_EQ(r->lanes, lanes);
*base = r->base;
return true;
}
} // namespace tir
} // namespace tvm
#endif // TVM_TIR_PASS_IR_UTIL_H_
8 changes: 4 additions & 4 deletions src/tir/transforms/inject_virtual_thread.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@ class ExprTouched final : public StmtExprVisitor {
}
void VisitExpr_(const CallNode *op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
int rw_mask = 0;
CHECK(arith::GetConstInt(op->args[4], &rw_mask));
const auto* rw_mask = op->args[4].as<IntImmNode>();
const VarNode* buffer_var = op->args[1].as<VarNode>();
CHECK(buffer_var);
CHECK(rw_mask);
// read
if (rw_mask & 1) {
if (rw_mask->value & 1) {
HandleUseVar(buffer_var);
}
if (rw_mask & 2) {
if (rw_mask->value & 2) {
HandleWriteVar(buffer_var);
}
this->VisitExpr(op->args[2]);
Expand Down
4 changes: 3 additions & 1 deletion src/tir/transforms/lower_thread_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
CHECK_GE(e.scope.dim_index, 0)
<< "vthread do not work with cross thread reduction";
if (e.scope.rank == 1) {
CHECK(arith::GetConstInt(attr->value, &(e.extent)))
const auto* ptr = attr->value.as<IntImmNode>();
CHECK(ptr)
<< "Need constant extent for reduce set " << iv;
e.extent = static_cast<int>(ptr->value);
if (reduce_set.count(iv->var.get())) {
vred.push_back(e);
++nmatch;
Expand Down
6 changes: 2 additions & 4 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
#include <unordered_set>

#include "../pass/ir_util.h"
#include "../../arith/compute_expr.h"

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -94,11 +93,10 @@ class BuiltinLower : public StmtExprMutator {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AllocateNode>();
// Get constant allocation bound.
int64_t dev_type;
int64_t nbytes = GetVectorBytes(op->dtype);
if (device_type_.defined()) {
if (arith::GetConst(device_type_, &dev_type)) {
if (dev_type == kDLCPU) {
if (const auto* dev_type = device_type_.as<IntImmNode>()) {
if (dev_type->value == kDLCPU) {
int32_t constant_size = op->constant_allocation_size();
if (constant_size > 0 && constant_size * nbytes < runtime::kMaxStackAlloca) {
return stmt;
Expand Down
38 changes: 20 additions & 18 deletions src/tir/transforms/lower_warp_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

#include <unordered_set>

#include "../pass/ir_util.h"
#include "../../arith/pattern_match.h"
#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"

Expand Down Expand Up @@ -121,11 +121,11 @@ class WarpStoreCoeffFinder : private StmtVisitor {
if (op->value.dtype().lanes() == 1) {
UpdatePattern(op->index);
} else {
PrimExpr base;
CHECK(GetRamp1Base(op->index, op->value.dtype().lanes(), &base))
arith::PVar<PrimExpr> base;
CHECK(arith::ramp(base, 1, op->value.dtype().lanes()).Match(op->index))
<< "LowerWarpMemory failed due to store index=" << op->index
<< ", can only handle continuous store";
UpdatePattern(base);
UpdatePattern(base.Eval());
}
} else {
StmtVisitor::VisitStmt_(op);
Expand All @@ -137,19 +137,18 @@ class WarpStoreCoeffFinder : private StmtVisitor {
arith::DetectLinearEquation(index, {warp_index_});
CHECK_EQ(m.size(), 2U)
<< "LowerWarpMemory failed due to store index=" << index;
int coeff = 0;
PrimExpr mcoeff = analyzer_->canonical_simplify(m[0]);

CHECK(arith::GetConstInt(mcoeff, &coeff) && coeff > 0)
const auto* mcoeff_as_int = mcoeff.as<IntImmNode>();
CHECK(mcoeff_as_int && mcoeff_as_int->value > 0)
<< "LowerWarpMemory failed due to store index=" << index
<< ", require positive constant coefficient on warp index " << warp_index_
<< " but get " << mcoeff;

if (warp_coeff_ != 0) {
CHECK_EQ(warp_coeff_, coeff)
CHECK_EQ(warp_coeff_, mcoeff_as_int->value)
<< "LowerWarpMemory failed due to two different store coefficient to warp index";
} else {
warp_coeff_ = coeff;
warp_coeff_ = mcoeff_as_int->value;
}
}

Expand All @@ -158,7 +157,7 @@ class WarpStoreCoeffFinder : private StmtVisitor {
// the warp index
Var warp_index_;
// the coefficient
int warp_coeff_{0};
int64_t warp_coeff_{0};
// analyzer.
arith::Analyzer* analyzer_;
};
Expand All @@ -184,10 +183,10 @@ class WarpIndexFinder : private StmtVisitor {
if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
int value = 0;
CHECK(arith::GetConstInt(op->value, &value) &&
value <= warp_size_ &&
warp_size_ % value == 0)
auto* value_as_int = op->value.as<IntImmNode>();
CHECK(value_as_int &&
value_as_int->value <= warp_size_ &&
warp_size_ % value_as_int->value == 0)
<< "Expect threadIdx.x 's size to be no larger than, and a factor of"
<< " warp size(" << warp_size_ << ")" << " to enable warp memory"
<< " but get " << op->value << " instead";
Expand All @@ -198,7 +197,7 @@ class WarpIndexFinder : private StmtVisitor {
<< "Please create it using thread_axis once and reuse the axis "
<< "across multiple binds in the same kernel";
} else {
width_ = value;
width_ = value_as_int->value;
warp_index_ = iv;
}
}
Expand Down Expand Up @@ -281,9 +280,12 @@ class WarpAccessRewriter : protected StmtExprMutator {
// in this access pattern.
std::pair<PrimExpr, PrimExpr> SplitIndexByGroup(const PrimExpr& index) {
if (index.dtype().lanes() != 1) {
PrimExpr base, local_index, group;
CHECK(GetRamp1Base(index, index.dtype().lanes(), &base));
std::tie(local_index, group) = SplitIndexByGroup(base);
PrimExpr local_index, group;

arith::PVar<PrimExpr> base;
CHECK(arith::ramp(base, 1, index.dtype().lanes()).Match(index));

std::tie(local_index, group) = SplitIndexByGroup(base.Eval());
local_index =
RampNode::make(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes());
return std::make_pair(local_index, group);
Expand Down
11 changes: 6 additions & 5 deletions src/tir/transforms/storage_flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -326,13 +326,14 @@ class StorageFlattener : public StmtExprMutator {
<< "Prefetch dim should be the same as buffer dim";

int block_size = 1,
elem_cnt = cache_line_size_ / e.buffer->dtype.bytes(),
shape = 0;
elem_cnt = cache_line_size_ / e.buffer->dtype.bytes();

int starts = op->bounds.size() - 1;
while (starts > 0 && arith::GetConstInt(e.buffer->shape[starts], &shape)
&& elem_cnt >= block_size * shape) {
block_size *= shape;

while (starts > 0) {
auto* shape_as_int = e.buffer->shape[starts].as<IntImmNode>();
if (shape_as_int == nullptr || block_size * shape_as_int->value > elem_cnt) break;
block_size *= static_cast<int>(shape_as_int->value);
starts--;
}
PrimExpr stride(elem_cnt / block_size);
Expand Down
7 changes: 2 additions & 5 deletions src/tir/transforms/unroll_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,13 @@ class LoopUnroller : public StmtExprMutator {

Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == "pragma_auto_unroll_max_step") {
int value = 0;
CHECK(arith::GetConstInt(op->value, &value));
int value = static_cast<int>(Downcast<Integer>(op->value)->value);
std::swap(value, auto_max_step_);
Stmt ret = this->VisitStmt(op->body);
std::swap(value, auto_max_step_);
return ret;
} else if (op->attr_key == "pragma_unroll_explicit") {
int value = 0;
CHECK(arith::GetConstInt(op->value, &value));
bool explicit_unroll = value;
bool explicit_unroll = Downcast<Integer>(op->value)->value;
std::swap(explicit_unroll, explicit_unroll_);
Stmt ret = this->VisitStmt(op->body);
std::swap(explicit_unroll, explicit_unroll_);
Expand Down
Loading

0 comments on commit e132a55

Please sign in to comment.