-
Notifications
You must be signed in to change notification settings - Fork 333
[Index] Relocate Int64 Auto Promoter to ConfigBitWidth Pass, removing it from FlattenBuffer #714
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
fd51304
c6cf3de
34c7205
6a9b8d9
31f5e2d
824c66b
9dfa6d3
34a8487
daac9c5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,4 +1,5 @@ | ||||||||||||||||||||||||||||||||||||||||||
| #include "../op/builtin.h" | ||||||||||||||||||||||||||||||||||||||||||
| #include "arith/ir_mutator_with_analyzer.h" | ||||||||||||||||||||||||||||||||||||||||||
| #include <tvm/ffi/function.h> | ||||||||||||||||||||||||||||||||||||||||||
| #include <tvm/ffi/reflection/registry.h> | ||||||||||||||||||||||||||||||||||||||||||
| #include <tvm/tir/builtin.h> | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -10,6 +11,7 @@ namespace tvm { | |||||||||||||||||||||||||||||||||||||||||
| namespace tl { | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| using namespace tir; | ||||||||||||||||||||||||||||||||||||||||||
| using namespace arith; | ||||||||||||||||||||||||||||||||||||||||||
| class ConfigIndexBitwidthRewriter : public IndexDataTypeRewriter { | ||||||||||||||||||||||||||||||||||||||||||
| public: | ||||||||||||||||||||||||||||||||||||||||||
| using Parent = IndexDataTypeRewriter; | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -68,6 +70,92 @@ class ConfigIndexBitwidthRewriter : public IndexDataTypeRewriter { | |||||||||||||||||||||||||||||||||||||||||
| int _index_bitwidth_; | ||||||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| class IndexLegalizer : public IRMutatorWithAnalyzer { | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| public: | ||||||||||||||||||||||||||||||||||||||||||
| static Stmt Rewrite(Stmt stmt) { | ||||||||||||||||||||||||||||||||||||||||||
| Analyzer ana; | ||||||||||||||||||||||||||||||||||||||||||
| auto pass = IndexLegalizer(&ana); | ||||||||||||||||||||||||||||||||||||||||||
| return pass.VisitStmt(stmt); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| private: | ||||||||||||||||||||||||||||||||||||||||||
| explicit IndexLegalizer(arith::Analyzer *ana) : IRMutatorWithAnalyzer(ana) {} | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| class Int64Promoter : public IndexDataTypeRewriter { | ||||||||||||||||||||||||||||||||||||||||||
| public: | ||||||||||||||||||||||||||||||||||||||||||
| using Parent = IndexDataTypeRewriter; | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| PrimExpr VisitExpr_(const VarNode *op) final { | ||||||||||||||||||||||||||||||||||||||||||
| if (op->dtype.is_int() && op->dtype.bits() < 64) { | ||||||||||||||||||||||||||||||||||||||||||
| return cast(DataType::Int(64), GetRef<Var>(op)); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| return GetRef<PrimExpr>(op); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| PrimExpr VisitExpr_(const IntImmNode *op) final { | ||||||||||||||||||||||||||||||||||||||||||
| if (op->dtype.is_int() && op->dtype.bits() < 64) { | ||||||||||||||||||||||||||||||||||||||||||
| return IntImm(DataType::Int(64), op->value); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| return GetRef<PrimExpr>(op); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| PrimExpr VisitExpr_(const CastNode *op) final { | ||||||||||||||||||||||||||||||||||||||||||
| if (op->dtype.is_int() && op->dtype.bits() < 64) { | ||||||||||||||||||||||||||||||||||||||||||
| return cast(DataType::Int(64), op->value); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| return GetRef<PrimExpr>(op); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Stmt VisitStmt_(const BufferStoreNode *op) final { | ||||||||||||||||||||||||||||||||||||||||||
| // Force indices to be int64 | ||||||||||||||||||||||||||||||||||||||||||
| auto node = Downcast<BufferStore>(Parent::VisitStmt_(op)); | ||||||||||||||||||||||||||||||||||||||||||
| return std::move(node); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| PrimExpr VisitExpr_(const BufferLoadNode *op) final { | ||||||||||||||||||||||||||||||||||||||||||
| auto node = Downcast<BufferLoad>(Parent::VisitExpr_(op)); | ||||||||||||||||||||||||||||||||||||||||||
| return std::move(node); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Stmt VisitStmt_(const BufferStoreNode *op) final { | ||||||||||||||||||||||||||||||||||||||||||
| auto buffer_store = | ||||||||||||||||||||||||||||||||||||||||||
| Downcast<BufferStore>(IRMutatorWithAnalyzer::VisitStmt_(op)); | ||||||||||||||||||||||||||||||||||||||||||
| auto indices = buffer_store->indices; | ||||||||||||||||||||||||||||||||||||||||||
| for (auto index : indices) { | ||||||||||||||||||||||||||||||||||||||||||
| if (index->dtype.is_int() && index->dtype.bits() < 64) { | ||||||||||||||||||||||||||||||||||||||||||
| auto int_bound = analyzer_->const_int_bound(index); | ||||||||||||||||||||||||||||||||||||||||||
| if (int_bound->max_value >= (1LL << (index->dtype.bits() - 1)) - 1 || | ||||||||||||||||||||||||||||||||||||||||||
| int_bound->min_value < -(1LL << (index->dtype.bits() - 1))) { | ||||||||||||||||||||||||||||||||||||||||||
| Int64Promoter promoter; | ||||||||||||||||||||||||||||||||||||||||||
| index = promoter(index); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+126
to
+135
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The range-based for loop for To fix this, you should iterate by reference (
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
| buffer_store.CopyOnWrite()->indices = indices; | ||||||||||||||||||||||||||||||||||||||||||
| return std::move(buffer_store); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| PrimExpr VisitExpr_(const BufferLoadNode *op) final { | ||||||||||||||||||||||||||||||||||||||||||
| auto buffer_load = | ||||||||||||||||||||||||||||||||||||||||||
| Downcast<BufferLoad>(IRMutatorWithAnalyzer::VisitExpr_(op)); | ||||||||||||||||||||||||||||||||||||||||||
| auto indices = buffer_load->indices; | ||||||||||||||||||||||||||||||||||||||||||
| for (auto index : indices) { | ||||||||||||||||||||||||||||||||||||||||||
| if (index->dtype.is_int() && index->dtype.bits() < 64) { | ||||||||||||||||||||||||||||||||||||||||||
| auto int_bound = analyzer_->const_int_bound(index); | ||||||||||||||||||||||||||||||||||||||||||
| if (int_bound->max_value >= (1LL << (index->dtype.bits() - 1)) - 1 || | ||||||||||||||||||||||||||||||||||||||||||
| int_bound->min_value < -(1LL << (index->dtype.bits() - 1))) { | ||||||||||||||||||||||||||||||||||||||||||
| Int64Promoter promoter; | ||||||||||||||||||||||||||||||||||||||||||
| index = promoter(index); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+144
to
+153
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the Please change the loop to iterate by reference (
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
| buffer_load.CopyOnWrite()->indices = indices; | ||||||||||||||||||||||||||||||||||||||||||
| return std::move(buffer_load); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| tvm::transform::Pass ConfigIndexBitwidth() { | ||||||||||||||||||||||||||||||||||||||||||
| using namespace tir::transform; | ||||||||||||||||||||||||||||||||||||||||||
| auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -81,6 +169,8 @@ tvm::transform::Pass ConfigIndexBitwidth() { | |||||||||||||||||||||||||||||||||||||||||
| n->body = ConfigIndexBitwidthRewriter(config_index_bitwidth)( | ||||||||||||||||||||||||||||||||||||||||||
| std::move(n->body)); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| // Legalize out-of-bound indices to be int64 | ||||||||||||||||||||||||||||||||||||||||||
| n->body = IndexLegalizer::Rewrite(std::move(n->body)); | ||||||||||||||||||||||||||||||||||||||||||
| return f; | ||||||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||||||
| return CreatePrimFuncPass(pass_func, 0, "tl.ConfigIndexBitwidth", {}); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The implementation of
VisitExpr_(const CastNode *op)does not recursively visit the expression within the cast (op->value). This can lead to incorrect behavior, especially ifop->valueis a complex expression that itself contains integers that need to be promoted. For example, an expression like(int32)a + (int32)bmight overflow before the cast toint64is applied ifaandbare not promoted first.To fix this, you should recursively call
VisitExpronop->valueto ensure all sub-expressions are properly promoted before the final cast.