Skip to content

Commit a961173

Browse files
authored
[Index] Relocate Int64 Auto Promoter to ConfigBitWidth Pass, removing it from FlattenBuffer (#714)
* Update submodule 'tvm' to commit e11521e6936a827efa334588d29571fbb4620107 * Refactor inject_pipeline.cc to enhance pipeline body rewriting and condition handling - Introduced a new function to replace IfThenElse nodes with their then_case while preserving attributes. - Streamlined the PipelineBodyRewriter to improve buffer access rewriting and async state management. - Enhanced the handling of pipeline loop conditions and added support for predicate conditions in the pipeline body. - Removed obsolete code and improved overall code clarity and maintainability. * lint fix * Refactor return statements in inject_pipeline.cc to remove unnecessary std::move calls - Updated return statements in multiple methods to return objects directly instead of using std::move, improving code clarity and potentially avoiding unnecessary moves. - Ensured consistent handling of BufferStore and BufferLoad nodes during pipeline transformations. * test fix * Enhance global read detection in pipeline planning - Updated the handling of global reads to account for condition expressions within IfThenElse nodes, ensuring accurate identification of global memory accesses. - Introduced a new flag to track whether the visitor is within a condition expression, improving the correctness of buffer access analysis. - Refactored the VisitStmt_ method to properly handle the structure of IfThenElse nodes, enhancing the clarity and maintainability of the code. * Add IndexLegalizer to enforce int64 for out-of-bound indices - Introduced the IndexLegalizer class to ensure that indices in BufferStore and BufferLoad nodes are promoted to int64 when they exceed their type bounds. - Refactored the Int64Promoter logic from flatten_buffer.cc into IndexLegalizer, improving code organization and reusability. - Updated the ConfigIndexBitwidth pass to apply IndexLegalizer after rewriting the body, enhancing the handling of index bitwidths in transformations.
1 parent c1eef51 commit a961173

File tree

2 files changed

+91
-60
lines changed

2 files changed

+91
-60
lines changed

src/transform/config_index_bitwidth.cc

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "../op/builtin.h"
2+
#include "arith/ir_mutator_with_analyzer.h"
23
#include <tvm/ffi/function.h>
34
#include <tvm/ffi/reflection/registry.h>
45
#include <tvm/tir/builtin.h>
@@ -10,6 +11,7 @@ namespace tvm {
1011
namespace tl {
1112

1213
using namespace tir;
14+
using namespace arith;
1315
class ConfigIndexBitwidthRewriter : public IndexDataTypeRewriter {
1416
public:
1517
using Parent = IndexDataTypeRewriter;
@@ -68,6 +70,92 @@ class ConfigIndexBitwidthRewriter : public IndexDataTypeRewriter {
6870
int _index_bitwidth_;
6971
};
7072

73+
class IndexLegalizer : public IRMutatorWithAnalyzer {
74+
75+
public:
76+
static Stmt Rewrite(Stmt stmt) {
77+
Analyzer ana;
78+
auto pass = IndexLegalizer(&ana);
79+
return pass.VisitStmt(stmt);
80+
}
81+
82+
private:
83+
explicit IndexLegalizer(arith::Analyzer *ana) : IRMutatorWithAnalyzer(ana) {}
84+
85+
class Int64Promoter : public IndexDataTypeRewriter {
86+
public:
87+
using Parent = IndexDataTypeRewriter;
88+
89+
PrimExpr VisitExpr_(const VarNode *op) final {
90+
if (op->dtype.is_int() && op->dtype.bits() < 64) {
91+
return cast(DataType::Int(64), GetRef<Var>(op));
92+
}
93+
return GetRef<PrimExpr>(op);
94+
}
95+
96+
PrimExpr VisitExpr_(const IntImmNode *op) final {
97+
if (op->dtype.is_int() && op->dtype.bits() < 64) {
98+
return IntImm(DataType::Int(64), op->value);
99+
}
100+
return GetRef<PrimExpr>(op);
101+
}
102+
103+
PrimExpr VisitExpr_(const CastNode *op) final {
104+
if (op->dtype.is_int() && op->dtype.bits() < 64) {
105+
return cast(DataType::Int(64), op->value);
106+
}
107+
return GetRef<PrimExpr>(op);
108+
}
109+
110+
Stmt VisitStmt_(const BufferStoreNode *op) final {
111+
// Force indices to be int64
112+
auto node = Downcast<BufferStore>(Parent::VisitStmt_(op));
113+
return std::move(node);
114+
}
115+
116+
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
117+
auto node = Downcast<BufferLoad>(Parent::VisitExpr_(op));
118+
return std::move(node);
119+
}
120+
};
121+
122+
Stmt VisitStmt_(const BufferStoreNode *op) final {
123+
auto buffer_store =
124+
Downcast<BufferStore>(IRMutatorWithAnalyzer::VisitStmt_(op));
125+
auto indices = buffer_store->indices;
126+
for (auto index : indices) {
127+
if (index->dtype.is_int() && index->dtype.bits() < 64) {
128+
auto int_bound = analyzer_->const_int_bound(index);
129+
if (int_bound->max_value >= (1LL << (index->dtype.bits() - 1)) - 1 ||
130+
int_bound->min_value < -(1LL << (index->dtype.bits() - 1))) {
131+
Int64Promoter promoter;
132+
index = promoter(index);
133+
}
134+
}
135+
}
136+
buffer_store.CopyOnWrite()->indices = indices;
137+
return std::move(buffer_store);
138+
}
139+
140+
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
141+
auto buffer_load =
142+
Downcast<BufferLoad>(IRMutatorWithAnalyzer::VisitExpr_(op));
143+
auto indices = buffer_load->indices;
144+
for (auto index : indices) {
145+
if (index->dtype.is_int() && index->dtype.bits() < 64) {
146+
auto int_bound = analyzer_->const_int_bound(index);
147+
if (int_bound->max_value >= (1LL << (index->dtype.bits() - 1)) - 1 ||
148+
int_bound->min_value < -(1LL << (index->dtype.bits() - 1))) {
149+
Int64Promoter promoter;
150+
index = promoter(index);
151+
}
152+
}
153+
}
154+
buffer_load.CopyOnWrite()->indices = indices;
155+
return std::move(buffer_load);
156+
}
157+
};
158+
71159
tvm::transform::Pass ConfigIndexBitwidth() {
72160
using namespace tir::transform;
73161
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
@@ -81,6 +169,8 @@ tvm::transform::Pass ConfigIndexBitwidth() {
81169
n->body = ConfigIndexBitwidthRewriter(config_index_bitwidth)(
82170
std::move(n->body));
83171
}
172+
// Legalize out-of-bound indices to be int64
173+
n->body = IndexLegalizer::Rewrite(std::move(n->body));
84174
return f;
85175
};
86176
return CreatePrimFuncPass(pass_func, 0, "tl.ConfigIndexBitwidth", {});

src/transform/flatten_buffer.cc

Lines changed: 1 addition & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -60,43 +60,6 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
6060
using IRMutatorWithAnalyzer::VisitStmt;
6161
using IRMutatorWithAnalyzer::VisitStmt_;
6262

63-
class Int64Promoter : public tir::IndexDataTypeRewriter {
64-
public:
65-
using Parent = IndexDataTypeRewriter;
66-
67-
PrimExpr VisitExpr_(const VarNode *op) final {
68-
if (op->dtype.is_int() && op->dtype.bits() < 64) {
69-
return cast(DataType::Int(64), GetRef<Var>(op));
70-
}
71-
return GetRef<PrimExpr>(op);
72-
}
73-
74-
PrimExpr VisitExpr_(const IntImmNode *op) final {
75-
if (op->dtype.is_int() && op->dtype.bits() < 64) {
76-
return IntImm(DataType::Int(64), op->value);
77-
}
78-
return GetRef<PrimExpr>(op);
79-
}
80-
81-
PrimExpr VisitExpr_(const CastNode *op) final {
82-
if (op->dtype.is_int() && op->dtype.bits() < 64) {
83-
return cast(DataType::Int(64), op->value);
84-
}
85-
return GetRef<PrimExpr>(op);
86-
}
87-
88-
Stmt VisitStmt_(const BufferStoreNode *op) final {
89-
// Force indices to be int64
90-
auto node = Downcast<BufferStore>(Parent::VisitStmt_(op));
91-
return std::move(node);
92-
}
93-
94-
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
95-
auto node = Downcast<BufferLoad>(Parent::VisitExpr_(op));
96-
return std::move(node);
97-
}
98-
};
99-
10063
explicit BufferFlattener(arith::Analyzer *ana) : IRMutatorWithAnalyzer(ana) {}
10164

10265
Stmt VisitStmt_(const BlockNode *op) final {
@@ -277,29 +240,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
277240
Array<PrimExpr> GetSimplifiedElemOffset(const Buffer &buffer,
278241
const Array<PrimExpr> &indices) {
279242
auto flattened_indices = buffer->ElemOffset(indices);
280-
Array<PrimExpr> safe_indices;
281-
for (auto index : flattened_indices) {
282-
auto int_bound = analyzer_->const_int_bound(index);
283-
DataType dtype = index->dtype;
284-
if (dtype.is_int() && dtype.bits() < 64) {
285-
int64_t max_value = int_bound->max_value;
286-
int64_t min_value = int_bound->min_value;
287-
const int64_t type_max = (1LL << (dtype.bits() - 1));
288-
const int64_t type_min = -(1LL << (dtype.bits() - 1));
289-
290-
if (max_value >= (type_max - 1) || min_value < type_min) {
291-
Int64Promoter promoter;
292-
for (auto &index : flattened_indices) {
293-
safe_indices.push_back(promoter(index));
294-
}
295-
} else {
296-
safe_indices.push_back(index);
297-
}
298-
} else {
299-
safe_indices.push_back(index);
300-
}
301-
}
302-
return this->IterMapSimplifyWithContext(safe_indices, false);
243+
return this->IterMapSimplifyWithContext(flattened_indices, false);
303244
}
304245

305246
template <typename Node> Node VisitBufferAccess(Node node) {

0 commit comments

Comments
 (0)