Skip to content
Merged
90 changes: 90 additions & 0 deletions src/transform/config_index_bitwidth.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "../op/builtin.h"
#include "arith/ir_mutator_with_analyzer.h"
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h>
Expand All @@ -10,6 +11,7 @@ namespace tvm {
namespace tl {

using namespace tir;
using namespace arith;
class ConfigIndexBitwidthRewriter : public IndexDataTypeRewriter {
public:
using Parent = IndexDataTypeRewriter;
Expand Down Expand Up @@ -68,6 +70,92 @@ class ConfigIndexBitwidthRewriter : public IndexDataTypeRewriter {
int _index_bitwidth_;
};

class IndexLegalizer : public IRMutatorWithAnalyzer {

public:
static Stmt Rewrite(Stmt stmt) {
Analyzer ana;
auto pass = IndexLegalizer(&ana);
return pass.VisitStmt(stmt);
}

private:
explicit IndexLegalizer(arith::Analyzer *ana) : IRMutatorWithAnalyzer(ana) {}

class Int64Promoter : public IndexDataTypeRewriter {
public:
using Parent = IndexDataTypeRewriter;

PrimExpr VisitExpr_(const VarNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) {
return cast(DataType::Int(64), GetRef<Var>(op));
}
return GetRef<PrimExpr>(op);
}

PrimExpr VisitExpr_(const IntImmNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) {
return IntImm(DataType::Int(64), op->value);
}
return GetRef<PrimExpr>(op);
}

PrimExpr VisitExpr_(const CastNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) {
return cast(DataType::Int(64), op->value);
}
return GetRef<PrimExpr>(op);
}
Comment on lines +103 to +108
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The implementation of VisitExpr_(const CastNode *op) does not recursively visit the expression within the cast (op->value). This can lead to incorrect behavior, especially if op->value is a complex expression that itself contains integers that need to be promoted. For example, an expression like (int32)a + (int32)b might overflow before the cast to int64 is applied if a and b are not promoted first.

To fix this, you should recursively call VisitExpr on op->value to ensure all sub-expressions are properly promoted before the final cast.

Suggested change
PrimExpr VisitExpr_(const CastNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) {
return cast(DataType::Int(64), op->value);
}
return GetRef<PrimExpr>(op);
}
PrimExpr VisitExpr_(const CastNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) {
return cast(DataType::Int(64), this->VisitExpr(op->value));
}
return Parent::VisitExpr_(op);
}


Stmt VisitStmt_(const BufferStoreNode *op) final {
// Force indices to be int64
auto node = Downcast<BufferStore>(Parent::VisitStmt_(op));
return std::move(node);
}

PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto node = Downcast<BufferLoad>(Parent::VisitExpr_(op));
return std::move(node);
}
};

Stmt VisitStmt_(const BufferStoreNode *op) final {
auto buffer_store =
Downcast<BufferStore>(IRMutatorWithAnalyzer::VisitStmt_(op));
auto indices = buffer_store->indices;
for (auto index : indices) {
if (index->dtype.is_int() && index->dtype.bits() < 64) {
auto int_bound = analyzer_->const_int_bound(index);
if (int_bound->max_value >= (1LL << (index->dtype.bits() - 1)) - 1 ||
int_bound->min_value < -(1LL << (index->dtype.bits() - 1))) {
Int64Promoter promoter;
index = promoter(index);
}
}
}
Comment on lines +126 to +135
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The range-based for loop for indices is iterating by value (auto index), which means index is a copy of the PrimExpr from the indices array. Assigning a new value to index with index = promoter(index) only modifies the local copy, not the element within the indices array. As a result, the intended promotion of indices to int64 does not happen.

To fix this, you should iterate by reference (auto& index) to allow in-place modification of the elements in the indices array.

Suggested change
for (auto index : indices) {
if (index->dtype.is_int() && index->dtype.bits() < 64) {
auto int_bound = analyzer_->const_int_bound(index);
if (int_bound->max_value >= (1LL << (index->dtype.bits() - 1)) - 1 ||
int_bound->min_value < -(1LL << (index->dtype.bits() - 1))) {
Int64Promoter promoter;
index = promoter(index);
}
}
}
for (auto& index : indices) {
if (index->dtype.is_int() && index->dtype.bits() < 64) {
auto int_bound = analyzer_->const_int_bound(index);
if (int_bound->max_value >= (1LL << (index->dtype.bits() - 1)) - 1 ||
int_bound->min_value < -(1LL << (index->dtype.bits() - 1))) {
Int64Promoter promoter;
index = promoter(index);
}
}
}

buffer_store.CopyOnWrite()->indices = indices;
return std::move(buffer_store);
}

PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto buffer_load =
Downcast<BufferLoad>(IRMutatorWithAnalyzer::VisitExpr_(op));
auto indices = buffer_load->indices;
for (auto index : indices) {
if (index->dtype.is_int() && index->dtype.bits() < 64) {
auto int_bound = analyzer_->const_int_bound(index);
if (int_bound->max_value >= (1LL << (index->dtype.bits() - 1)) - 1 ||
int_bound->min_value < -(1LL << (index->dtype.bits() - 1))) {
Int64Promoter promoter;
index = promoter(index);
}
}
}
Comment on lines +144 to +153
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Similar to the BufferStoreNode visitor, the range-based for loop for indices iterates by value. This prevents the intended promotion of indices to int64 from being applied to the indices array.

Please change the loop to iterate by reference (auto& index) to ensure the modifications are saved.

Suggested change
for (auto index : indices) {
if (index->dtype.is_int() && index->dtype.bits() < 64) {
auto int_bound = analyzer_->const_int_bound(index);
if (int_bound->max_value >= (1LL << (index->dtype.bits() - 1)) - 1 ||
int_bound->min_value < -(1LL << (index->dtype.bits() - 1))) {
Int64Promoter promoter;
index = promoter(index);
}
}
}
for (auto& index : indices) {
if (index->dtype.is_int() && index->dtype.bits() < 64) {
auto int_bound = analyzer_->const_int_bound(index);
if (int_bound->max_value >= (1LL << (index->dtype.bits() - 1)) - 1 ||
int_bound->min_value < -(1LL << (index->dtype.bits() - 1))) {
Int64Promoter promoter;
index = promoter(index);
}
}
}

buffer_load.CopyOnWrite()->indices = indices;
return std::move(buffer_load);
}
};

tvm::transform::Pass ConfigIndexBitwidth() {
using namespace tir::transform;
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
Expand All @@ -81,6 +169,8 @@ tvm::transform::Pass ConfigIndexBitwidth() {
n->body = ConfigIndexBitwidthRewriter(config_index_bitwidth)(
std::move(n->body));
}
// Legalize out-of-bound indices to be int64
n->body = IndexLegalizer::Rewrite(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.ConfigIndexBitwidth", {});
Expand Down
61 changes: 1 addition & 60 deletions src/transform/flatten_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,43 +60,6 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
using IRMutatorWithAnalyzer::VisitStmt;
using IRMutatorWithAnalyzer::VisitStmt_;

class Int64Promoter : public tir::IndexDataTypeRewriter {
public:
using Parent = IndexDataTypeRewriter;

PrimExpr VisitExpr_(const VarNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) {
return cast(DataType::Int(64), GetRef<Var>(op));
}
return GetRef<PrimExpr>(op);
}

PrimExpr VisitExpr_(const IntImmNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) {
return IntImm(DataType::Int(64), op->value);
}
return GetRef<PrimExpr>(op);
}

PrimExpr VisitExpr_(const CastNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) {
return cast(DataType::Int(64), op->value);
}
return GetRef<PrimExpr>(op);
}

Stmt VisitStmt_(const BufferStoreNode *op) final {
// Force indices to be int64
auto node = Downcast<BufferStore>(Parent::VisitStmt_(op));
return std::move(node);
}

PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto node = Downcast<BufferLoad>(Parent::VisitExpr_(op));
return std::move(node);
}
};

explicit BufferFlattener(arith::Analyzer *ana) : IRMutatorWithAnalyzer(ana) {}

Stmt VisitStmt_(const BlockNode *op) final {
Expand Down Expand Up @@ -277,29 +240,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
Array<PrimExpr> GetSimplifiedElemOffset(const Buffer &buffer,
const Array<PrimExpr> &indices) {
auto flattened_indices = buffer->ElemOffset(indices);
Array<PrimExpr> safe_indices;
for (auto index : flattened_indices) {
auto int_bound = analyzer_->const_int_bound(index);
DataType dtype = index->dtype;
if (dtype.is_int() && dtype.bits() < 64) {
int64_t max_value = int_bound->max_value;
int64_t min_value = int_bound->min_value;
const int64_t type_max = (1LL << (dtype.bits() - 1));
const int64_t type_min = -(1LL << (dtype.bits() - 1));

if (max_value >= (type_max - 1) || min_value < type_min) {
Int64Promoter promoter;
for (auto &index : flattened_indices) {
safe_indices.push_back(promoter(index));
}
} else {
safe_indices.push_back(index);
}
} else {
safe_indices.push_back(index);
}
}
return this->IterMapSimplifyWithContext(safe_indices, false);
return this->IterMapSimplifyWithContext(flattened_indices, false);
}

template <typename Node> Node VisitBufferAccess(Node node) {
Expand Down