From b43a7c0abcf386ea8414e01830c09a29fda8a2a1 Mon Sep 17 00:00:00 2001 From: Freebase6912 Date: Wed, 5 Nov 2025 15:32:35 +0800 Subject: [PATCH] Add LegalizeNegativeIndex pass for better negative index handling --- src/transform/legalize_negative_index.cc | 160 ++++++++++++++++++ .../test_tilelang_language_negative_index.py | 60 +++++++ tilelang/engine/phase.py | 2 + tilelang/transform/__init__.py | 11 ++ 4 files changed, 233 insertions(+) create mode 100644 src/transform/legalize_negative_index.cc create mode 100644 testing/python/language/test_tilelang_language_negative_index.py diff --git a/src/transform/legalize_negative_index.cc b/src/transform/legalize_negative_index.cc new file mode 100644 index 000000000..a1713d835 --- /dev/null +++ b/src/transform/legalize_negative_index.cc @@ -0,0 +1,160 @@ +/*! + * \file legalize_negative_index.cc + * \brief Legalize negative indices in buffer load expressions. + */ + +#include +#include +#include +#include + +#include +#include + +#include "arith/ir_mutator_with_analyzer.h" +#include "arith/ir_visitor_with_analyzer.h" + +namespace tvm { +namespace tl { + +using namespace tir; +using arith::IRVisitorWithAnalyzer; + +enum class IndexSignState { kNonNegative, kNegative, kUnknown }; + +class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer { +public: + explicit NegativeIndexAnalyzer( + std::unordered_map> + *result) + : result_(result) {} + + void VisitExpr_(const BufferLoadNode *op) final { + auto load = tvm::ffi::GetRef(op); + std::vector states; + states.reserve(op->indices.size()); + bool needs_record = false; + + for (size_t i = 0; i < op->indices.size(); ++i) { + PrimExpr simplified = analyzer_.Simplify(op->indices[i]); + if (analyzer_.CanProve(simplified >= 0)) { + states.push_back(IndexSignState::kNonNegative); + continue; + } + + if (analyzer_.CanProve(simplified < 0)) { + states.push_back(IndexSignState::kNegative); + needs_record = true; + continue; + } + + states.push_back(IndexSignState::kUnknown); + needs_record = true; + LOG(WARNING) << "LegalizeNegativeIndex: cannot prove non-negative index " + << simplified << " for buffer " << load->buffer->name + << " (axis " << i << ")."; + } + + if (needs_record) { + (*result_)[op] = std::move(states); + } + + IRVisitorWithAnalyzer::VisitExpr_(op); + } + +private: + std::unordered_map> + *result_; +}; + +class NegativeIndexRewriter : public arith::IRMutatorWithAnalyzer { +public: + static PrimFunc + Apply(PrimFunc func, + const std::unordered_map> &states) { + arith::Analyzer analyzer; + NegativeIndexRewriter rewriter(&analyzer, states); + if (!func->body.defined()) { + return func; + } + PrimFuncNode *func_node = func.CopyOnWrite(); + func_node->body = rewriter.VisitStmt(func_node->body); + return func; + } + +private: + NegativeIndexRewriter( + arith::Analyzer *analyzer, + const std::unordered_map> &states) + : arith::IRMutatorWithAnalyzer(analyzer), states_(states) {} + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + BufferLoad load = + Downcast(arith::IRMutatorWithAnalyzer::VisitExpr_(op)); + + auto it = states_.find(op); + if (it == states_.end()) { + return load; + } + + auto indices = load->indices; + bool changed = false; + + const auto &state_vector = it->second; + ICHECK_EQ(state_vector.size(), indices.size()) + << "State vector size mismatch for buffer load " << load->buffer->name; + + for (size_t i = 0; i < indices.size(); ++i) { + if (state_vector[i] != IndexSignState::kNegative) { + continue; + } + PrimExpr extent = load->buffer->shape[i]; + indices.Set(i, analyzer_->Simplify(extent + indices[i])); + changed = true; + } + + if (!changed) { + return load; + } + + return BufferLoad(load->buffer, indices); + } + + const std::unordered_map> + &states_; +}; + +PrimFunc LegalizeNegativeIndex(PrimFunc func) { + if (!func->body.defined()) { + return func; + } + + std::unordered_map> + states; + NegativeIndexAnalyzer analyzer(&states); + analyzer(func->body); + if (states.empty()) { + return func; + } + + return NegativeIndexRewriter::Apply(std::move(func), states); +} + +tvm::transform::Pass LegalizeNegativeIndexPass() { + using namespace tir::transform; + auto pass_func = [](PrimFunc f, const IRModule &, PassContext) { + return LegalizeNegativeIndex(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.LegalizeNegativeIndex", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.LegalizeNegativeIndex", + LegalizeNegativeIndexPass); +} + +} // namespace tl +} // namespace tvm diff --git a/testing/python/language/test_tilelang_language_negative_index.py b/testing/python/language/test_tilelang_language_negative_index.py new file mode 100644 index 000000000..4a0df878b --- /dev/null +++ b/testing/python/language/test_tilelang_language_negative_index.py @@ -0,0 +1,60 @@ +from tilelang import tvm +import tilelang as tl +import tilelang.testing +from tvm.script import tir as T + + +@T.prim_func +def negative_index_before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): + T.func_attr({"tir.noalias": True}) + B[0] = A[T.int32(-1)] + + +@T.prim_func +def negative_index_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): + T.func_attr({"tir.noalias": True}) + B[0] = A[T.int32(15)] + + +@T.prim_func +def negative_index_loop_before(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")): + T.func_attr({"tir.noalias": True}) + for i in T.serial(4): + B[i] = A[-i - 1] + + +@T.prim_func +def negative_index_loop_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")): + T.func_attr({"tir.noalias": True}) + for i in T.serial(4): + B[i] = A[15 - i] + + +@T.prim_func +def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), "float32"), + B: T.Buffer((16,), "float32")): + T.func_attr({"tir.noalias": True}) + for i in T.serial(16): + B[i] = A[shift + i] + + +def test_legalize_negative_index_scalar(): + mod = tvm.IRModule({"main": negative_index_before}) + transformed = tl.transform.LegalizeNegativeIndex()(mod) + tvm.ir.assert_structural_equal(transformed["main"].body, negative_index_expected.body) + + +def test_legalize_negative_index_affine_expr(): + mod = tvm.IRModule({"main": negative_index_loop_before}) + transformed = tl.transform.LegalizeNegativeIndex()(mod) + tvm.ir.assert_structural_equal(transformed["main"].body, negative_index_loop_expected.body) + + +def test_legalize_negative_index_symbolic_passthrough(): + mod = tvm.IRModule({"main": negative_index_symbolic_before}) + transformed = tl.transform.LegalizeNegativeIndex()(mod) + tvm.ir.assert_structural_equal(transformed["main"].body, negative_index_symbolic_before.body) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 10fd87d10..26a0bea37 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -96,6 +96,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.LetInline()(mod) # Add wrapper for single buf store mod = tilelang.transform.AddWrapperForSingleBufStore()(mod) + # Normalize negative indices to canonical non-negative form + mod = tilelang.transform.LegalizeNegativeIndex()(mod) # Inject assumes to speedup tvm prover mod = tilelang.transform.InjectAssumes()(mod) # Simplify the IR expressions diff --git a/tilelang/transform/__init__.py b/tilelang/transform/__init__.py index d16a81d6e..bd305b325 100644 --- a/tilelang/transform/__init__.py +++ b/tilelang/transform/__init__.py @@ -80,6 +80,17 @@ def FrontendLegalize(): return _ffi_api.FrontendLegalize() # type: ignore +def LegalizeNegativeIndex(): + """Legalize negative indices in buffer loads. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LegalizeNegativeIndex() # type: ignore + + def InjectAssumes(): """Inject Assumes