From 0a1aac68eb132755ea27d7e066c6a3f3c1287e90 Mon Sep 17 00:00:00 2001
From: ZhouXin <zhou.xin@mail.ustc.edu.cn>
Date: Wed, 25 Dec 2024 08:29:25 +0000
Subject: [PATCH 1/8] Update longlong2int pass

---
 paddle/cinn/optim/CMakeLists.txt           |   2 +-
 paddle/cinn/optim/longlong2int.cc          | 191 -----------------
 paddle/cinn/optim/longlong2int.h           |  24 ---
 paddle/cinn/optim/longlong2int_pass.cc     | 230 +++++++++++++++++++++
 paddle/cinn/optim/longlong2int_pass.h      | 104 ++++++++++
 paddle/cinn/optim/transform_gpu_forloop.cc |  14 +-
 6 files changed, 347 insertions(+), 218 deletions(-)
 delete mode 100644 paddle/cinn/optim/longlong2int.cc
 delete mode 100644 paddle/cinn/optim/longlong2int.h
 create mode 100644 paddle/cinn/optim/longlong2int_pass.cc
 create mode 100644 paddle/cinn/optim/longlong2int_pass.h

diff --git a/paddle/cinn/optim/CMakeLists.txt b/paddle/cinn/optim/CMakeLists.txt
index 6d2ae9b159df89..92682e90b79240 100755
--- a/paddle/cinn/optim/CMakeLists.txt
+++ b/paddle/cinn/optim/CMakeLists.txt
@@ -38,7 +38,7 @@ gather_srcs(
   eliminate_common_global_memory_read.cc
   rearrange_load_instruction.cc
   check_tensor_buffer_map.cc
-  longlong2int.cc
+  longlong2int_pass.cc
   vectorize_for_trans.cc)
 
 if(WITH_CUDA OR WITH_ROCM)
diff --git a/paddle/cinn/optim/longlong2int.cc b/paddle/cinn/optim/longlong2int.cc
deleted file mode 100644
index de158332ac1ef9..00000000000000
--- a/paddle/cinn/optim/longlong2int.cc
+++ /dev/null
@@ -1,191 +0,0 @@
-// Copyright (c) 2024 CINN Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/cinn/optim/longlong2int.h"
-#include "paddle/cinn/ir/ir_mutator.h"
-#include "paddle/cinn/ir/ir_printer.h"
-#include "paddle/cinn/ir/ir_utils.h"
-#include "paddle/cinn/ir/ir_visitor.h"
-
-namespace cinn {
-namespace optim {
-
-class CheckOverflow : public ir::IRVisitor {
- public:
-  bool is_overflow(Expr* expr) {
-    ir::IRVisitor::Visit(expr);
-    return is_overflow_;
-  }
-
- private:
-  void Visit(const ir::For* for_op) override {
-    if (!for_op->extent.is_constant()) is_overflow_ = true;
-    if (!for_op->extent.type().is_index_type()) is_overflow_ = true;
-    if (curr_product_ > INT_MAX) is_overflow_ = true;
-
-    if (is_overflow_) return;
-
-    curr_product_ *= for_op->extent.as_int64();
-    ir::IRVisitor::Visit(&for_op->body);
-    curr_product_ /= for_op->extent.as_int64();
-  }
-  void Visit(const ir::ScheduleBlock* op) override {
-    ir::IRVisitor::Visit(&(op->body));
-  }
-  void Visit(const ir::ScheduleBlockRealize* op) override {
-    ir::IRVisitor::Visit(&(op->schedule_block));
-  }
-  void Visit(const ir::Block* op) {
-    for (auto& expr : op->stmts) {
-      ir::IRVisitor::Visit(&expr);
-    }
-  }
-  void Visit(const ir::IfThenElse* op) {
-    ir::IRVisitor::Visit(&(op->true_case));
-    if (op->false_case.defined()) ir::IRVisitor::Visit(&(op->false_case));
-  }
-  int64_t curr_product_ = 1;
-  bool is_overflow_ = false;
-};
-
-class CastLonglong2Int : public ir::IRMutator<> {
- public:
-  void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
-
- private:
-  void Visit(const ir::_Tensor_* op, Expr* expr) override {
-    auto node = expr->As<ir::_Tensor_>();
-    std::for_each(node->shape.begin(),
-                  node->shape.end(),
-                  [&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); });
-    CastBufferMeta(node->buffer);
-  }
-  void Visit(const ir::Load* op, Expr* expr) override {
-    auto node = expr->As<ir::Load>();
-    std::for_each(node->indices.begin(),
-                  node->indices.end(),
-                  [&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); });
-
-    ir::IRMutator<>::Visit(&node->tensor, &node->tensor);
-  }
-  void Visit(const ir::Store* op, Expr* expr) override {
-    auto node = expr->As<ir::Store>();
-    std::for_each(node->indices.begin(),
-                  node->indices.end(),
-                  [&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); });
-    ir::IRMutator<>::Visit(&node->value, &node->value);
-    ir::IRMutator<>::Visit(&node->tensor, &node->tensor);
-  }
-  void Visit(const ir::IfThenElse* op, Expr* expr) override {
-    auto node = expr->As<ir::IfThenElse>();
-    auto cond = node->condition;
-    if (cond.is_cmp()) {
-      if (cond->operand(0).is_index())
-        cond->operand(0)->convert_int64_to_int32();
-      if (cond->operand(1).is_index())
-        cond->operand(1)->convert_int64_to_int32();
-    }
-    ir::IRMutator<>::Visit(&node->true_case, &node->true_case);
-    if (node->false_case.defined()) {
-      ir::IRMutator<>::Visit(&node->false_case, &node->false_case);
-    }
-  }
-  void Visit(const ir::Select* op, Expr* expr) override {
-    auto node = expr->As<ir::Select>();
-    auto cond = node->condition;
-    if (cond.is_cmp()) {
-      if (cond->operand(0).is_index())
-        cond->operand(0)->convert_int64_to_int32();
-      if (cond->operand(1).is_index())
-        cond->operand(1)->convert_int64_to_int32();
-    }
-    ir::IRMutator<>::Visit(&node->true_value, &node->true_value);
-    ir::IRMutator<>::Visit(&node->false_value, &node->false_value);
-  }
-  void Visit(const ir::For* op, Expr* expr) override {
-    auto node = expr->As<ir::For>();
-    CastVarWithBound(node->loop_var);
-    node->min->convert_int64_to_int32();
-    node->extent->convert_int64_to_int32();
-    ir::IRMutator<>::Visit(&node->body, &node->body);
-  }
-  void Visit(const ir::ScheduleBlock* op, Expr* expr) override {
-    auto* node = expr->As<ir::ScheduleBlock>();
-
-    std::for_each(node->iter_vars.begin(),
-                  node->iter_vars.end(),
-                  [&](cinn::ir::Var& v) { CastVarWithBound(v); });
-
-    for (auto& buffer_range : node->read_buffers) {
-      if (auto range = buffer_range.As<ir::_BufferRange_>()) {
-        std::for_each(range->ranges.begin(),
-                      range->ranges.end(),
-                      [&](cinn::ir::Var& v) { CastVarWithBound(v); });
-        auto bf = range->buffer.as_buffer_ref();
-        CastBufferMeta(bf);
-      }
-    }
-
-    for (auto& buffer_range : node->write_buffers) {
-      if (auto range = buffer_range.As<ir::_BufferRange_>()) {
-        std::for_each(range->ranges.begin(),
-                      range->ranges.end(),
-                      [&](cinn::ir::Var& v) { CastVarWithBound(v); });
-        auto bf = range->buffer.as_buffer_ref();
-        CastBufferMeta(bf);
-      }
-    }
-    ir::IRMutator<>::Visit(&(node->body), &(node->body));
-  }
-
-  void Visit(const ir::ScheduleBlockRealize* op, Expr* expr) override {
-    auto* node = expr->As<ir::ScheduleBlockRealize>();
-
-    std::for_each(node->iter_values.begin(),
-                  node->iter_values.end(),
-                  [&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); });
-    ir::IRMutator<>::Visit(&node->schedule_block, &node->schedule_block);
-  }
-
-  void CastVarWithBound(cinn::ir::Var& var) {  // NOLINT
-    if (!var.defined()) return;
-    var->convert_int64_to_int32();
-    auto lb = var->lower_bound;
-    auto ub = var->upper_bound;
-    if (lb.defined()) lb->convert_int64_to_int32();
-    if (ub.defined()) ub->convert_int64_to_int32();
-  }
-  void CastBufferMeta(cinn::ir::Buffer& bf) {  // NOLINT
-    if (!bf.defined()) return;
-    std::for_each(bf->shape.begin(), bf->shape.end(), [&](cinn::ir::Expr& e) {
-      e->convert_int64_to_int32();
-    });
-    std::for_each(bf->strides.begin(),
-                  bf->strides.end(),
-                  [&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); });
-    bf->elem_offset->convert_int64_to_int32();
-  }
-};
-
-void TryCastLonglong2Int(Expr* expr) {
-  VLOG(6) << "Before TryCastLonglong2Int, Expr = \n" << *expr;
-  CheckOverflow check_overflow;
-  if (!check_overflow.is_overflow(expr)) {
-    CastLonglong2Int narrow;
-    narrow(expr);
-  }
-  VLOG(6) << "After TryCastLonglong2Int, Expr = \n" << *expr;
-}
-}  // namespace optim
-}  // namespace cinn
diff --git a/paddle/cinn/optim/longlong2int.h b/paddle/cinn/optim/longlong2int.h
deleted file mode 100644
index b72e70df603a82..00000000000000
--- a/paddle/cinn/optim/longlong2int.h
+++ /dev/null
@@ -1,24 +0,0 @@
-// Copyright (c) 2024 CINN Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-#include "paddle/cinn/ir/ir.h"
-
-namespace cinn {
-namespace optim {
-
-// Try to change the type of longlong to int in the expr.
-void TryCastLonglong2Int(Expr* expr);
-}  // namespace optim
-}  // namespace cinn
diff --git a/paddle/cinn/optim/longlong2int_pass.cc b/paddle/cinn/optim/longlong2int_pass.cc
new file mode 100644
index 00000000000000..734878345d46c3
--- /dev/null
+++ b/paddle/cinn/optim/longlong2int_pass.cc
@@ -0,0 +1,230 @@
+// Copyright (c) 2024 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/cinn/optim/longlong2int_pass.h"
+#include "paddle/cinn/ir/ir_mutator.h"
+#include "paddle/cinn/ir/ir_printer.h"
+#include "paddle/cinn/ir/ir_utils.h"
+#include "paddle/cinn/ir/ir_visitor.h"
+#include "paddle/cinn/ir/stmt.h"
+#include "paddle/cinn/ir/stmt_visitors.h"
+
+namespace cinn {
+namespace optim {
+namespace {
+using ir::stmt::BlockRef;
+using ir::stmt::For;
+using ir::stmt::IfThenElse;
+using ir::stmt::Schedule;
+using ir::stmt::StmtRef;
+using ir::stmt::Store;
+
+class CheckOverflow : public ir::stmt::StmtVisitor<> {
+ public:
+  bool operator()(const StmtRef& stmt) {
+    VisitStmt(stmt);
+    return is_overflow_;
+  }
+  bool operator()(const BlockRef& block) {
+    VisitBlock(block);
+    return is_overflow_;
+  }
+
+ private:
+  void VisitStmt(const StmtRef& stmt) override {
+    if (is_overflow_) return;
+    ir::stmt::StmtVisitor<>::VisitStmt(stmt);
+  }
+
+  void VisitStmt(const For& for_stmt) override {
+    if (!for_stmt->extent().is_constant()) is_overflow_ = true;
+    if (!for_stmt->extent().type().is_index_type()) is_overflow_ = true;
+    if (curr_product_ > INT_MAX) is_overflow_ = true;
+
+    if (is_overflow_) return;
+
+    curr_product_ *= for_stmt->extent().as_int64();
+    VisitBlock(for_stmt->body());
+    curr_product_ /= for_stmt->extent().as_int64();
+  }
+
+  void VisitStmt(const Schedule& schedule_stmt) override {
+    VisitBlock(schedule_stmt->body());
+  }
+
+  void VisitStmt(const IfThenElse& stmt) override {
+    VisitBlock(stmt->true_case());
+    if (stmt->false_case().defined()) {
+      VisitBlock(stmt->false_case());
+    }
+  }
+
+  void VisitStmt(const ir::stmt::Let& stmt) override { return; }
+  void VisitStmt(const ir::stmt::Store& stmt) override { return; }
+  void VisitStmt(const ir::stmt::Alloc& stmt) override { return; }
+  void VisitStmt(const ir::stmt::Free& stmt) override { return; }
+  void VisitStmt(const ir::stmt::Evaluate& stmt) override { return; }
+
+ private:
+  int64_t curr_product_ = 1;
+  bool is_overflow_ = false;
+};
+
+class CastLonglong2Int : public ir::IRMutator<>,
+                         public ir::stmt::StmtMutator<> {
+ public:
+  void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
+  void operator()(StmtRef stmt) { ir::stmt::StmtMutator<>::VisitStmt(stmt); }
+  void operator()(BlockRef block) {
+    ir::stmt::StmtMutator<>::VisitBlock(block);
+  }
+
+ private:
+  void Visit(const ir::_Tensor_* op, Expr* expr) override {
+    auto node = expr->As<ir::_Tensor_>();
+    std::for_each(node->shape.begin(),
+                  node->shape.end(),
+                  [&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); });
+    CastBufferMeta(node->buffer);
+  }
+  void Visit(const ir::Load* op, Expr* expr) override {
+    auto node = expr->As<ir::Load>();
+    std::for_each(node->indices.begin(),
+                  node->indices.end(),
+                  [&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); });
+
+    ir::IRMutator<>::Visit(&node->tensor, &node->tensor);
+  }
+  void Visit(const ir::Select* op, Expr* expr) override {
+    auto node = expr->As<ir::Select>();
+    auto cond = node->condition;
+    if (cond.is_cmp()) {
+      if (cond->operand(0).is_index())
+        cond->operand(0)->convert_int64_to_int32();
+      if (cond->operand(1).is_index())
+        cond->operand(1)->convert_int64_to_int32();
+    }
+    ir::IRMutator<>::Visit(&node->true_value, &node->true_value);
+    ir::IRMutator<>::Visit(&node->false_value, &node->false_value);
+  }
+  void VisitStmt(Store stmt) override {
+    std::vector<Expr> indices = stmt->indices();
+    std::for_each(indices.begin(), indices.end(), [&](cinn::ir::Expr& e) {
+      e->convert_int64_to_int32();
+    });
+    Expr value = stmt->value();
+    Expr tensor = stmt->tensor();
+    ir::IRMutator<>::Visit(&value, &value);
+    ir::IRMutator<>::Visit(&tensor, &tensor);
+  }
+  void VisitStmt(IfThenElse stmt) override {
+    Expr cond = stmt->condition();
+    if (cond.is_cmp()) {
+      if (cond->operand(0).is_index())
+        cond->operand(0)->convert_int64_to_int32();
+      if (cond->operand(1).is_index())
+        cond->operand(1)->convert_int64_to_int32();
+    }
+    ir::stmt::StmtMutator<>::VisitBlock(stmt->true_case());
+    if (stmt->false_case().defined()) {
+      ir::stmt::StmtMutator<>::VisitBlock(stmt->false_case());
+    }
+  }
+  void VisitStmt(For stmt) override {
+    ir::Var loop_var = stmt->loop_var();
+    CastVarWithBound(loop_var);
+    stmt->set_loop_var(loop_var);
+    stmt->min()->convert_int64_to_int32();
+    stmt->extent()->convert_int64_to_int32();
+    ir::stmt::StmtMutator<>::VisitBlock(stmt->body());
+  }
+  void VisitStmt(Schedule stmt) override {
+    std::vector<Var> iter_vars = stmt->iter_vars();
+    std::for_each(iter_vars.begin(), iter_vars.end(), [&](cinn::ir::Var& v) {
+      CastVarWithBound(v);
+    });
+
+    for (auto& buffer_range : stmt->read_buffers()) {
+      if (auto range = buffer_range.As<ir::_BufferRange_>()) {
+        std::vector<Var> ranges = range->ranges;
+        std::for_each(ranges.begin(), ranges.end(), [&](cinn::ir::Var& v) {
+          CastVarWithBound(v);
+        });
+        auto bf = range->buffer.as_buffer_ref();
+        CastBufferMeta(bf);
+      }
+    }
+
+    for (auto& buffer_range : stmt->write_buffers()) {
+      if (auto range = buffer_range.As<ir::_BufferRange_>()) {
+        std::vector<Var> ranges = range->ranges;
+
+        std::for_each(ranges.begin(), ranges.end(), [&](cinn::ir::Var& v) {
+          CastVarWithBound(v);
+        });
+        auto bf = range->buffer.as_buffer_ref();
+        CastBufferMeta(bf);
+      }
+    }
+    ir::stmt::StmtMutator<>::VisitBlock(stmt->body());
+  }
+  void VisitStmt(ir::stmt::Let stmt) override {
+    Expr body = stmt->body();
+    ir::IRMutator<>::Visit(&body, &body);
+  }
+  void VisitStmt(ir::stmt::Evaluate stmt) override {
+    Expr value = stmt->value();
+    ir::IRMutator<>::Visit(&value, &value);
+  }
+
+  void VisitStmt(ir::stmt::Alloc stmt) override { return; }
+  void VisitStmt(ir::stmt::Free stmt) override { return; }
+
+  void CastVarWithBound(cinn::ir::Var& var) {  // NOLINT
+    if (!var.defined()) return;
+    var->convert_int64_to_int32();
+    auto lb = var->lower_bound;
+    auto ub = var->upper_bound;
+    if (lb.defined()) lb->convert_int64_to_int32();
+    if (ub.defined()) ub->convert_int64_to_int32();
+  }
+  void CastBufferMeta(cinn::ir::Buffer& bf) {  // NOLINT
+    if (!bf.defined()) return;
+    std::for_each(bf->shape.begin(), bf->shape.end(), [&](cinn::ir::Expr& e) {
+      e->convert_int64_to_int32();
+    });
+    std::for_each(bf->strides.begin(),
+                  bf->strides.end(),
+                  [&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); });
+    bf->elem_offset->convert_int64_to_int32();
+  }
+};
+}  // namespace
+
+LogicalResult LongLong2IntPass::Run(ir::stmt::StmtRef stmt) {
+  CastLonglong2Int narrow;
+  narrow(stmt);
+  return LogicalResult::success();
+}
+
+std::unique_ptr<StmtPass> CreateLongLong2IntPass() {
+  return std::make_unique<LongLong2IntPass>();
+}
+
+bool CanApplyLongLong2Int(ir::stmt::BlockRef block) {
+  CheckOverflow check_overflow;
+  return !check_overflow(block);
+}
+}  // namespace optim
+}  // namespace cinn
diff --git a/paddle/cinn/optim/longlong2int_pass.h b/paddle/cinn/optim/longlong2int_pass.h
new file mode 100644
index 00000000000000..3441bf8fdf434e
--- /dev/null
+++ b/paddle/cinn/optim/longlong2int_pass.h
@@ -0,0 +1,104 @@
+// Copyright (c) 2024 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/cinn/ir/stmt.h"
+#include "paddle/cinn/pass/pass.h"
+
+namespace cinn {
+namespace optim {
+class LongLong2IntPass : public StmtPass {
+ public:
+  LongLong2IntPass() : StmtPass("longlong2int") {}
+  LogicalResult Run(ir::stmt::StmtRef stmt) override;
+};
+
+/**
+ * Converts int64 (long long) types to int32 in a block where possible.
+ *
+ * IMPORTANT: Before applying this pass, it is MANDATORY to use
+ * `CanApplyLongLong2Int` to check for potential overflow issues.
+ *
+ * This pass is applicable in scenarios where the IR contains int64 types that
+ * can be safely represented as int32 without overflow.
+ *
+ * When applied, this pass will traverse the IR and convert int64 types to int32
+ * in various constructs, including:
+ * - Tensor shapes and indices
+ * - Loop variables and bounds
+ * - Buffer metadata (shapes, strides, offsets)
+ * - Comparison operations
+ *
+ * Overflow checking:
+ * The pass performs overflow checking primarily for nested for-loops. This
+ * focus on nested loops is based on the assumption that they are the most
+ * common source of potential overflows in typical computational kernels. The
+ * check considers:
+ * - The product of loop extents (iteration counts)
+ * - Whether loop bounds are constant and of index type
+ *
+ *
+ * Examples:
+ * 1. Loop variable conversion:
+ * Before conversion:
+ * {
+ *   ScheduleBlock(root_12)
+ *   {
+ *     attrs(tile_method:TileFirstGeneralTactic)
+ *     thread_bind[blockIdx.x] for (blockIdx.x, 0, 352)
+ *     {
+ *       thread_bind[threadIdx.x] for (threadIdx.x, 0, 256)
+ *       {
+ *         ScheduleBlock(var_2)
+ *         {
+ *           i0, i1, i2, i3 = axis.bind(idx / 4096, (idx % 4096) / 256, (idx %
+ * 256) / 16, idx % 16) read_buffers(_var[i0(0:22ll), i2(0:16ll)])
+ *           write_buffers(_var_2[i0(0:22ll), i1(0:16ll), i2(0:16ll),
+ * i3(0:16ll)])
+ *         var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16ll]
+ *         }
+ *       }
+ *     }
+ *   }
+ * }
+ *
+ * After conversion:
+ * {
+ *   ScheduleBlock(root_12)
+ *   {
+ *     attrs(tile_method:TileFirstGeneralTactic)
+ *     thread_bind[blockIdx.x] for (blockIdx.x, 0, 352)
+ *     {
+ *       thread_bind[threadIdx.x] for (threadIdx.x, 0, 256)
+ *       {
+ *         ScheduleBlock(var_2)
+ *         {
+ *           i0, i1, i2, i3 = axis.bind(idx / 4096, (idx % 4096) / 256, (idx %
+ * 256) / 16, idx % 16) read_buffers(_var[i0(0:22), i2(0:16)])
+ *           write_buffers(_var_2[i0(0:22), i1(0:16), i2(0:16),i3(0:16)])
+ *           var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16]
+ *         }
+ *       }
+ *     }
+ *   }
+ * }
+ */
+std::unique_ptr<StmtPass> CreateLongLong2IntPass();
+
+// Check if the given block can be converted from long long to int,
+// i.e., the product of the extents of all possible nested loops is within
+// INT_MAX
+bool CanApplyLongLong2Int(ir::stmt::BlockRef block);
+}  // namespace optim
+}  // namespace cinn
diff --git a/paddle/cinn/optim/transform_gpu_forloop.cc b/paddle/cinn/optim/transform_gpu_forloop.cc
index 10610ed0fd0361..c949d201143c02 100644
--- a/paddle/cinn/optim/transform_gpu_forloop.cc
+++ b/paddle/cinn/optim/transform_gpu_forloop.cc
@@ -28,12 +28,14 @@
 #include "paddle/cinn/ir/ir_mutator.h"
 #include "paddle/cinn/ir/ir_printer.h"
 #include "paddle/cinn/ir/utils/ir_copy.h"
+#include "paddle/cinn/ir/utils/stmt_converter.h"
 #include "paddle/cinn/optim/eliminate_common_factor_of_local_index.h"
 #include "paddle/cinn/optim/ir_simplify.h"
-#include "paddle/cinn/optim/longlong2int.h"
+#include "paddle/cinn/optim/longlong2int_pass.h"
 #include "paddle/cinn/optim/replace_var_with_expr.h"
 #include "paddle/cinn/optim/resize_buffer.h"
 #include "paddle/cinn/optim/update_buffer_axis_pass.h"
+#include "paddle/cinn/pass/pass_manager.h"
 #include "paddle/cinn/poly/isl_utils.h"
 #include "paddle/cinn/poly/stage.h"
 #include "paddle/cinn/runtime/intrinsic.h"
@@ -493,7 +495,15 @@ void OptimizeExprGPU(Expr *expr) {
   ResizeBufferToMaxVarRange(expr);
 
   if (FLAGS_cinn_longlong2int) {
-    TryCastLonglong2Int(expr);
+    ir::stmt::BlockRef block = ir::ConvertExprBlockToStmtBlock(*expr);
+    if (CanApplyLongLong2Int(block)) {
+      VLOG(10) << "Before LongLong2IntPass: \n" << *expr;
+      StmtPassManager pass_manager;
+      pass_manager.AddPass(CreateLongLong2IntPass());
+      pass_manager.Run(block);
+      *expr = ir::ConvertStmtBlockToExprBlock(block);
+      VLOG(10) << "After LongLong2IntPass: \n" << *expr;
+    }
   }
 
   VLOG(4) << "After Optimize Expr: \n" << *expr;
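
[Editor's sketch, not part of the patch: the overflow criterion that
CheckOverflow enforces above, reduced to plain integers. Narrowing is only
attempted when every loop extent is a compile-time constant of index type and
the product of the extents stays within INT_MAX. The helper name below is
illustrative.]

    #include <climits>
    #include <cstdint>
    #include <vector>

    // Simplified version of the check: returns true when every index of a
    // loop nest with the given constant extents fits into int32.
    bool LoopNestFitsInt32(const std::vector<int64_t>& constant_extents) {
      int64_t product = 1;
      for (int64_t extent : constant_extents) {
        product *= extent;                    // iterations covered so far
        if (product > INT_MAX) return false;  // linearized index may overflow
      }
      return true;
    }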

From d2cbcf6e5b4ec343ba7340169959bc66040082a1 Mon Sep 17 00:00:00 2001
From: ZhouXin <zhou.xin@mail.ustc.edu.cn>
Date: Fri, 27 Dec 2024 04:18:25 +0000
Subject: [PATCH 2/8] Split ll2int into two passes

---
 paddle/cinn/optim/longlong2int_pass.cc     | 43 +++++-----
 paddle/cinn/optim/longlong2int_pass.h      | 99 ++++++++++++++++++----
 paddle/cinn/optim/transform_gpu_forloop.cc | 10 ++-
 3 files changed, 111 insertions(+), 41 deletions(-)

diff --git a/paddle/cinn/optim/longlong2int_pass.cc b/paddle/cinn/optim/longlong2int_pass.cc
index 734878345d46c3..c9d3291be0cff3 100644
--- a/paddle/cinn/optim/longlong2int_pass.cc
+++ b/paddle/cinn/optim/longlong2int_pass.cc
@@ -103,9 +103,9 @@ class CastLonglong2Int : public ir::IRMutator<>,
     std::for_each(node->indices.begin(),
                   node->indices.end(),
                   [&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); });
-
     ir::IRMutator<>::Visit(&node->tensor, &node->tensor);
   }
+
   void Visit(const ir::Select* op, Expr* expr) override {
     auto node = expr->As<ir::Select>();
     auto cond = node->condition;
@@ -123,10 +123,6 @@ class CastLonglong2Int : public ir::IRMutator<>,
     std::for_each(indices.begin(), indices.end(), [&](cinn::ir::Expr& e) {
       e->convert_int64_to_int32();
     });
-    Expr value = stmt->value();
-    Expr tensor = stmt->tensor();
-    ir::IRMutator<>::Visit(&value, &value);
-    ir::IRMutator<>::Visit(&tensor, &tensor);
   }
   void VisitStmt(IfThenElse stmt) override {
     Expr cond = stmt->condition();
@@ -136,18 +132,12 @@ class CastLonglong2Int : public ir::IRMutator<>,
       if (cond->operand(1).is_index())
         cond->operand(1)->convert_int64_to_int32();
     }
-    ir::stmt::StmtMutator<>::VisitBlock(stmt->true_case());
-    if (stmt->false_case().defined()) {
-      ir::stmt::StmtMutator<>::VisitBlock(stmt->false_case());
-    }
   }
   void VisitStmt(For stmt) override {
     ir::Var loop_var = stmt->loop_var();
     CastVarWithBound(loop_var);
-    stmt->set_loop_var(loop_var);
     stmt->min()->convert_int64_to_int32();
     stmt->extent()->convert_int64_to_int32();
-    ir::stmt::StmtMutator<>::VisitBlock(stmt->body());
   }
   void VisitStmt(Schedule stmt) override {
     std::vector<Var> iter_vars = stmt->iter_vars();
@@ -155,6 +145,11 @@ class CastLonglong2Int : public ir::IRMutator<>,
       CastVarWithBound(v);
     });
 
+    std::vector<Expr> iter_values = stmt->iter_values();
+    std::for_each(iter_values.begin(),
+                  iter_values.end(),
+                  [&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); });
+
     for (auto& buffer_range : stmt->read_buffers()) {
       if (auto range = buffer_range.As<ir::_BufferRange_>()) {
         std::vector<Var> ranges = range->ranges;
@@ -179,14 +174,8 @@ class CastLonglong2Int : public ir::IRMutator<>,
     }
     ir::stmt::StmtMutator<>::VisitBlock(stmt->body());
   }
-  void VisitStmt(ir::stmt::Let stmt) override {
-    Expr body = stmt->body();
-    ir::IRMutator<>::Visit(&body, &body);
-  }
-  void VisitStmt(ir::stmt::Evaluate stmt) override {
-    Expr value = stmt->value();
-    ir::IRMutator<>::Visit(&value, &value);
-  }
+  void VisitStmt(ir::stmt::Let stmt) override { return; }
+  void VisitStmt(ir::stmt::Evaluate stmt) override { return; }
 
   void VisitStmt(ir::stmt::Alloc stmt) override { return; }
   void VisitStmt(ir::stmt::Free stmt) override { return; }
@@ -212,19 +201,29 @@ class CastLonglong2Int : public ir::IRMutator<>,
 };
 }  // namespace
 
-LogicalResult LongLong2IntPass::Run(ir::stmt::StmtRef stmt) {
+LogicalResult LongLong2IntStmtPass::Run(ir::stmt::StmtRef stmt) {
   CastLonglong2Int narrow;
   narrow(stmt);
   return LogicalResult::success();
 }
 
-std::unique_ptr<StmtPass> CreateLongLong2IntPass() {
-  return std::make_unique<LongLong2IntPass>();
+LogicalResult LongLong2IntExprPass::Run(ir::Expr expr) {
+  CastLonglong2Int narrow;
+  narrow(&expr);
+  return LogicalResult::success();
+}
+std::unique_ptr<StmtPass> CreateLongLong2IntStmtPass() {
+  return std::make_unique<LongLong2IntStmtPass>();
+}
+
+std::unique_ptr<ExprPass> CreateLongLong2IntExprPass() {
+  return std::make_unique<LongLong2IntExprPass>();
 }
 
 bool CanApplyLongLong2Int(ir::stmt::BlockRef block) {
   CheckOverflow check_overflow;
   return !check_overflow(block);
 }
+
 }  // namespace optim
 }  // namespace cinn
diff --git a/paddle/cinn/optim/longlong2int_pass.h b/paddle/cinn/optim/longlong2int_pass.h
index 3441bf8fdf434e..d912ebea8ab51a 100644
--- a/paddle/cinn/optim/longlong2int_pass.h
+++ b/paddle/cinn/optim/longlong2int_pass.h
@@ -18,14 +18,20 @@
 
 namespace cinn {
 namespace optim {
-class LongLong2IntPass : public StmtPass {
+class LongLong2IntStmtPass : public StmtPass {
  public:
-  LongLong2IntPass() : StmtPass("longlong2int") {}
+  LongLong2IntStmtPass() : StmtPass("longlong2int_stmt") {}
   LogicalResult Run(ir::stmt::StmtRef stmt) override;
 };
 
+class LongLong2IntExprPass : public ExprPass {
+ public:
+  LongLong2IntExprPass() : ExprPass("longlong2int_expr") {}
+  LogicalResult Run(ir::Expr expr) override;
+};
+
 /**
- * Converts int64 (long long) types to int32 in a block where possible.
+ * Converts int64 (long long) types to int32 in a Stmt where possible.
  *
  * IMPORTANT: Before applying this pass, it is MANDATORY to use
  * `CanApplyLongLong2Int` to check for potential overflow issues.
@@ -33,21 +39,12 @@ class LongLong2IntPass : public StmtPass {
  * This pass is applicable in scenarios where the IR contains int64 types that
  * can be safely represented as int32 without overflow.
  *
- * When applied, this pass will traverse the IR and convert int64 types to int32
+ * When applied, this pass will convert int64 expressions to int32
  * in various constructs, including:
  * - Tensor shapes and indices
  * - Loop variables and bounds
  * - Buffer metadata (shapes, strides, offsets)
- * - Comparison operations
- *
- * Overflow checking:
- * The pass performs overflow checking primarily for nested for-loops. This
- * focus on nested loops is based on the assumption that they are the most
- * common source of potential overflows in typical computational kernels. The
- * check considers:
- * - The product of loop extents (iteration counts)
- * - Whether loop bounds are constant and of index type
- *
+ * - Comparison operations (index only)
  *
  * Examples:
  * 1. Loop variable conversion:
@@ -87,14 +84,84 @@ class LongLong2IntPass : public StmtPass {
  *           i0, i1, i2, i3 = axis.bind(idx / 4096, (idx % 4096) / 256, (idx %
  * 256) / 16, idx % 16) read_buffers(_var[i0(0:22), i2(0:16)])
  *           write_buffers(_var_2[i0(0:22), i1(0:16), i2(0:16),i3(0:16)])
- *           var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16]
+ *           var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16ll]
+ *         }
+ *       }
+ *     }
+ *   }
+ * }
+ *
+ * The 16ll in var[i0, i2, i3 + i1 * 16ll] is not converted because it is
+ * part of a Load expression, which will be converted in LongLong2IntExprPass.
+ */
+std::unique_ptr<StmtPass> CreateLongLong2IntStmtPass();
+
+/**
+ * Converts int64 (long long) types to int32 in a Stmt where possible.
+ *
+ * IMPORTANT: Before applying this pass, it is MANDATORY to use
+ * `CanApplyLongLong2Int` to check for potential overflow issues.
+ *
+ * This pass is applicable in scenarios where the IR contains int64 types that
+ * can be safely represented as int32 without overflow.
+ *
+ * When applied, this pass will convert int64 expressions to int32
+ * in various constructs, including:
+ * - Tensor shapes and indices
+ * - Loop variables and bounds
+ * - Buffer metadata (shapes, strides, offsets)
+ * - Comparison operations (index only)
+ *
+ * Examples:
+ * 1. Loop variable conversion:
+ * Before conversion:
+ * {
+ *   ScheduleBlock(root_12)
+ *   {
+ *     attrs(tile_method:TileFirstGeneralTactic)
+ *     thread_bind[blockIdx.x] for (blockIdx.x, 0, 352)
+ *     {
+ *       thread_bind[threadIdx.x] for (threadIdx.x, 0, 256)
+ *       {
+ *         ScheduleBlock(var_2)
+ *         {
+ *           i0, i1, i2, i3 = axis.bind(idx / 4096, (idx % 4096) / 256, (idx %
+ * 256) / 16, idx % 16) read_buffers(_var[i0(0:22ll), i2(0:16ll)])
+ *           write_buffers(_var_2[i0(0:22ll), i1(0:16ll), i2(0:16ll),
+ * i3(0:16ll)])
+ *         var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16ll]
  *         }
  *       }
  *     }
  *   }
  * }
+ *
+ * After conversion:
+ * {
+ *   ScheduleBlock(root_12)
+ *   {
+ *     attrs(tile_method:TileFirstGeneralTactic)
+ *     thread_bind[blockIdx.x] for (blockIdx.x, 0, 352)
+ *     {
+ *       thread_bind[threadIdx.x] for (threadIdx.x, 0, 256)
+ *       {
+ *         ScheduleBlock(var_2)
+ *         {
+ *           i0, i1, i2, i3 = axis.bind(idx / 4096ll, (idx % 4096ll) / 256ll,
+ * (idx % 256ll) / 16ll, idx % 16ll) read_buffers(_var[i0(0:22ll), i2(0:16ll)])
+ *           write_buffers(_var_2[i0(0:22ll), i1(0:16ll),
+ * i2(0:16ll),i3(0:16ll)]) var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16]
+ *         }
+ *       }
+ *     }
+ *   }
+ * }
+ *
+ * Only the 16ll in var[i0, i2, i3 + i1 * 16ll] is converted, because the
+ * other long long Exprs are components of the ScheduleBlock and will be
+ * converted in LongLong2IntStmtPass.
  */
-std::unique_ptr<StmtPass> CreateLongLong2IntPass();
+std::unique_ptr<ExprPass> CreateLongLong2IntExprPass();
 
 // Check if the given block can be converted from long long to int,
 // i.e., the product of the extents of all possible nested loops is within
diff --git a/paddle/cinn/optim/transform_gpu_forloop.cc b/paddle/cinn/optim/transform_gpu_forloop.cc
index c949d201143c02..5139b64dec2861 100644
--- a/paddle/cinn/optim/transform_gpu_forloop.cc
+++ b/paddle/cinn/optim/transform_gpu_forloop.cc
@@ -497,12 +497,16 @@ void OptimizeExprGPU(Expr *expr) {
   if (FLAGS_cinn_longlong2int) {
     ir::stmt::BlockRef block = ir::ConvertExprBlockToStmtBlock(*expr);
     if (CanApplyLongLong2Int(block)) {
-      VLOG(10) << "Before LongLong2IntPass: \n" << *expr;
+      VLOG(10) << "Before LongLong2IntStmtPass: \n" << *expr;
       StmtPassManager pass_manager;
-      pass_manager.AddPass(CreateLongLong2IntPass());
+      pass_manager.AddPass(CreateLongLong2IntStmtPass());
       pass_manager.Run(block);
+      VLOG(10) << "After LongLong2IntStmtPass: \n" << block;
+      ExprPassManager expr_pass_manager;
+      expr_pass_manager.AddPass(CreateLongLong2IntExprPass());
+      expr_pass_manager.Run(block);
+      VLOG(10) << "After LongLong2IntExprPass: \n" << block;
       *expr = ir::ConvertStmtBlockToExprBlock(block);
-      VLOG(10) << "After LongLong2IntPass: \n" << *expr;
     }
   }
 

From 73e89823e709cbaf2b833b83b58e05d391143119 Mon Sep 17 00:00:00 2001
From: ZhouXin <zhou.xin@mail.ustc.edu.cn>
Date: Wed, 25 Dec 2024 06:49:45 +0000
Subject: [PATCH 3/8] Apply cherry-pick


From 95def9f5470690b7283b9b38b0e0eac07f088ac4 Mon Sep 17 00:00:00 2001
From: ZhouXin <zhou.xin@mail.ustc.edu.cn>
Date: Fri, 27 Dec 2024 07:28:53 +0000
Subject: [PATCH 4/8] Extract stmt logic from mutator into StmtPass

---
 paddle/cinn/optim/longlong2int_pass.cc | 128 ++++++++++++++-----------
 1 file changed, 71 insertions(+), 57 deletions(-)

diff --git a/paddle/cinn/optim/longlong2int_pass.cc b/paddle/cinn/optim/longlong2int_pass.cc
index c9d3291be0cff3..08b2f199943f8a 100644
--- a/paddle/cinn/optim/longlong2int_pass.cc
+++ b/paddle/cinn/optim/longlong2int_pass.cc
@@ -30,6 +30,25 @@ using ir::stmt::Schedule;
 using ir::stmt::StmtRef;
 using ir::stmt::Store;
 
+void CastVarWithBound(cinn::ir::Var& var) {  // NOLINT
+  if (!var.defined()) return;
+  var->convert_int64_to_int32();
+  auto lb = var->lower_bound;
+  auto ub = var->upper_bound;
+  if (lb.defined()) lb->convert_int64_to_int32();
+  if (ub.defined()) ub->convert_int64_to_int32();
+}
+void CastBufferMeta(cinn::ir::Buffer& bf) {  // NOLINT
+  if (!bf.defined()) return;
+  std::for_each(bf->shape.begin(), bf->shape.end(), [&](cinn::ir::Expr& e) {
+    e->convert_int64_to_int32();
+  });
+  std::for_each(bf->strides.begin(), bf->strides.end(), [&](cinn::ir::Expr& e) {
+    e->convert_int64_to_int32();
+  });
+  bf->elem_offset->convert_int64_to_int32();
+}
+
 class CheckOverflow : public ir::stmt::StmtVisitor<> {
  public:
   bool operator()(const StmtRef& stmt) {
@@ -81,14 +100,9 @@ class CheckOverflow : public ir::stmt::StmtVisitor<> {
   bool is_overflow_ = false;
 };
 
-class CastLonglong2Int : public ir::IRMutator<>,
-                         public ir::stmt::StmtMutator<> {
+class CastLonglong2IntMutator : public ir::IRMutator<> {
  public:
   void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
-  void operator()(StmtRef stmt) { ir::stmt::StmtMutator<>::VisitStmt(stmt); }
-  void operator()(BlockRef block) {
-    ir::stmt::StmtMutator<>::VisitBlock(block);
-  }
 
  private:
   void Visit(const ir::_Tensor_* op, Expr* expr) override {
@@ -118,39 +132,50 @@ class CastLonglong2Int : public ir::IRMutator<>,
     ir::IRMutator<>::Visit(&node->true_value, &node->true_value);
     ir::IRMutator<>::Visit(&node->false_value, &node->false_value);
   }
-  void VisitStmt(Store stmt) override {
-    std::vector<Expr> indices = stmt->indices();
-    std::for_each(indices.begin(), indices.end(), [&](cinn::ir::Expr& e) {
-      e->convert_int64_to_int32();
-    });
-  }
-  void VisitStmt(IfThenElse stmt) override {
-    Expr cond = stmt->condition();
+};
+
+}  // namespace
+
+LogicalResult LongLong2IntStmtPass::Run(ir::stmt::StmtRef stmt) {
+  auto CastStore = [](StmtRef stmt) {
+    Store store_stmt = stmt.as<Store>();
+    for (Expr index : store_stmt->indices()) {
+      index->convert_int64_to_int32();
+    }
+  };
+
+  auto CastIfThenElse = [](StmtRef stmt) {
+    IfThenElse if_stmt = stmt.as<IfThenElse>();
+    Expr cond = if_stmt->condition();
     if (cond.is_cmp()) {
       if (cond->operand(0).is_index())
         cond->operand(0)->convert_int64_to_int32();
       if (cond->operand(1).is_index())
         cond->operand(1)->convert_int64_to_int32();
     }
-  }
-  void VisitStmt(For stmt) override {
-    ir::Var loop_var = stmt->loop_var();
+  };
+
+  auto CastFor = [](StmtRef stmt) {
+    For for_stmt = stmt.as<For>();
+    ir::Var loop_var = for_stmt->loop_var();
     CastVarWithBound(loop_var);
-    stmt->min()->convert_int64_to_int32();
-    stmt->extent()->convert_int64_to_int32();
-  }
-  void VisitStmt(Schedule stmt) override {
-    std::vector<Var> iter_vars = stmt->iter_vars();
+    for_stmt->min()->convert_int64_to_int32();
+    for_stmt->extent()->convert_int64_to_int32();
+  };
+
+  auto CastSchedule = [](StmtRef stmt) {
+    Schedule schedule_stmt = stmt.as<Schedule>();
+    std::vector<Var> iter_vars = schedule_stmt->iter_vars();
     std::for_each(iter_vars.begin(), iter_vars.end(), [&](cinn::ir::Var& v) {
       CastVarWithBound(v);
     });
 
-    std::vector<Expr> iter_values = stmt->iter_values();
+    std::vector<Expr> iter_values = schedule_stmt->iter_values();
     std::for_each(iter_values.begin(),
                   iter_values.end(),
                   [&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); });
 
-    for (auto& buffer_range : stmt->read_buffers()) {
+    for (auto& buffer_range : schedule_stmt->read_buffers()) {
       if (auto range = buffer_range.As<ir::_BufferRange_>()) {
         std::vector<Var> ranges = range->ranges;
         std::for_each(ranges.begin(), ranges.end(), [&](cinn::ir::Var& v) {
@@ -161,7 +186,7 @@ class CastLonglong2Int : public ir::IRMutator<>,
       }
     }
 
-    for (auto& buffer_range : stmt->write_buffers()) {
+    for (auto& buffer_range : schedule_stmt->write_buffers()) {
       if (auto range = buffer_range.As<ir::_BufferRange_>()) {
         std::vector<Var> ranges = range->ranges;
 
@@ -172,43 +197,32 @@ class CastLonglong2Int : public ir::IRMutator<>,
         CastBufferMeta(bf);
       }
     }
-    ir::stmt::StmtMutator<>::VisitBlock(stmt->body());
-  }
-  void VisitStmt(ir::stmt::Let stmt) override { return; }
-  void VisitStmt(ir::stmt::Evaluate stmt) override { return; }
-
-  void VisitStmt(ir::stmt::Alloc stmt) override { return; }
-  void VisitStmt(ir::stmt::Free stmt) override { return; }
-
-  void CastVarWithBound(cinn::ir::Var& var) {  // NOLINT
-    if (!var.defined()) return;
-    var->convert_int64_to_int32();
-    auto lb = var->lower_bound;
-    auto ub = var->upper_bound;
-    if (lb.defined()) lb->convert_int64_to_int32();
-    if (ub.defined()) ub->convert_int64_to_int32();
-  }
-  void CastBufferMeta(cinn::ir::Buffer& bf) {  // NOLINT
-    if (!bf.defined()) return;
-    std::for_each(bf->shape.begin(), bf->shape.end(), [&](cinn::ir::Expr& e) {
-      e->convert_int64_to_int32();
-    });
-    std::for_each(bf->strides.begin(),
-                  bf->strides.end(),
-                  [&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); });
-    bf->elem_offset->convert_int64_to_int32();
-  }
-};
-}  // namespace
+  };
 
-LogicalResult LongLong2IntStmtPass::Run(ir::stmt::StmtRef stmt) {
-  CastLonglong2Int narrow;
-  narrow(stmt);
+  switch (stmt->stmt_type()) {
+    case ir::StmtNodeTy::Store:
+      CastStore(stmt);
+      break;
+
+    case ir::StmtNodeTy::IfThenElse:
+      CastIfThenElse(stmt);
+      break;
+
+    case ir::StmtNodeTy::For:
+      CastFor(stmt);
+      break;
+
+    case ir::StmtNodeTy::Schedule:
+      CastSchedule(stmt);
+      break;
+    default:
+      break;
+  }
   return LogicalResult::success();
 }
 
 LogicalResult LongLong2IntExprPass::Run(ir::Expr expr) {
-  CastLonglong2Int narrow;
+  CastLonglong2IntMutator narrow;
   narrow(&expr);
   return LogicalResult::success();
 }

From baf7ccb7105e5d8eca7b838b0365a2343234e866 Mon Sep 17 00:00:00 2001
From: ZhouXin <zhou.xin@mail.ustc.edu.cn>
Date: Fri, 27 Dec 2024 07:34:16 +0000
Subject: [PATCH 5/8] Refine comment

---
 paddle/cinn/optim/longlong2int_pass.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/cinn/optim/longlong2int_pass.h b/paddle/cinn/optim/longlong2int_pass.h
index d912ebea8ab51a..ddb9f14f453d82 100644
--- a/paddle/cinn/optim/longlong2int_pass.h
+++ b/paddle/cinn/optim/longlong2int_pass.h
@@ -97,7 +97,7 @@ class LongLong2IntExprPass : public ExprPass {
 std::unique_ptr<StmtPass> CreateLongLong2IntStmtPass();
 
 /**
- * Converts int64 (long long) types to int32 in a Stmt where possible.
+ * Converts int64 (long long) types to int32 in an Expr where possible.
  *
  * IMPORTANT: Before applying this pass, it is MANDATORY to use
  * `CanApplyLongLong2Int` to check for potential overflow issues.

From 82ed67384253d385d5deee28d03ae48cef7bf335 Mon Sep 17 00:00:00 2001
From: ZhouXin <zhou.xin@mail.ustc.edu.cn>
Date: Mon, 30 Dec 2024 03:23:30 +0000
Subject: [PATCH 6/8] Implement CastLonglong2Int function to convert int64
 types to int32 with overflow checks

---
 paddle/cinn/optim/longlong2int_pass.cc     | 13 ++++
 paddle/cinn/optim/longlong2int_pass.h      | 70 ++++++++++++++++++++++
 paddle/cinn/optim/transform_gpu_forloop.cc | 15 +----
 paddle/cinn/pass/pass_adaptor.h            |  1 +
 4 files changed, 87 insertions(+), 12 deletions(-)

diff --git a/paddle/cinn/optim/longlong2int_pass.cc b/paddle/cinn/optim/longlong2int_pass.cc
index 08b2f199943f8a..7d03515ce0990c 100644
--- a/paddle/cinn/optim/longlong2int_pass.cc
+++ b/paddle/cinn/optim/longlong2int_pass.cc
@@ -19,6 +19,7 @@
 #include "paddle/cinn/ir/ir_visitor.h"
 #include "paddle/cinn/ir/stmt.h"
 #include "paddle/cinn/ir/stmt_visitors.h"
+#include "paddle/cinn/pass/pass_manager.h"
 
 namespace cinn {
 namespace optim {
@@ -239,5 +240,17 @@ bool CanApplyLongLong2Int(ir::stmt::BlockRef block) {
   return !check_overflow(block);
 }
 
+void CastLonglong2Int(ir::stmt::BlockRef block) {
+  if (CanApplyLongLong2Int(block)) {
+    StmtPassManager stmt_pass_manager;
+    stmt_pass_manager.AddPass(CreateLongLong2IntStmtPass());
+    ExprPassManager expr_pass_manager;
+    expr_pass_manager.AddPass(CreateLongLong2IntExprPass());
+
+    stmt_pass_manager.Run(block);
+    expr_pass_manager.Run(block);
+  }
+}
+
 }  // namespace optim
 }  // namespace cinn
diff --git a/paddle/cinn/optim/longlong2int_pass.h b/paddle/cinn/optim/longlong2int_pass.h
index ddb9f14f453d82..f6c99b3df66fc9 100644
--- a/paddle/cinn/optim/longlong2int_pass.h
+++ b/paddle/cinn/optim/longlong2int_pass.h
@@ -167,5 +167,75 @@ std::unique_ptr<ExprPass> CreateLongLong2IntExprPass();
 // i.e., the product of the extents of all possible nested loops is within
 // INT_MAX
 bool CanApplyLongLong2Int(ir::stmt::BlockRef block);
+
+/**
+ * Converts int64 (long long) types to int32 in a block where possible.
+ *
+ * This pass is applicable in scenarios where the IR contains int64 types that
+ * can be safely represented as int32 without overflow.
+ *
+ * When applied, this pass will traverse the IR and convert int64 types to int32
+ * in various constructs, including:
+ * - Tensor shapes and indices
+ * - Loop variables and bounds
+ * - Buffer metadata (shapes, strides, offsets)
+ * - Comparison operations
+ *
+ * Overflow checking:
+ * The pass performs overflow checking primarily for nested for-loops. This
+ * focus on nested loops is based on the assumption that they are the most
+ * common source of potential overflows in typical computational kernels. The
+ * check considers:
+ * - The product of loop extents (iteration counts)
+ * - Whether loop bounds are constant and of index type
+ *
+ *
+ * Examples:
+ * 1. Loop variable conversion:
+ * Before conversion:
+ * {
+ *   ScheduleBlock(root_12)
+ *   {
+ *     attrs(tile_method:TileFirstGeneralTactic)
+ *     thread_bind[blockIdx.x] for (blockIdx.x, 0, 352)
+ *     {
+ *       thread_bind[threadIdx.x] for (threadIdx.x, 0, 256)
+ *       {
+ *         ScheduleBlock(var_2)
+ *         {
+ *           i0, i1, i2, i3 = axis.bind(idx / 4096, (idx % 4096) / 256, (idx %
+ * 256) / 16, idx % 16) read_buffers(_var[i0(0:22ll), i2(0:16ll)])
+ *           write_buffers(_var_2[i0(0:22ll), i1(0:16ll), i2(0:16ll),
+ * i3(0:16ll)])
+ *         var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16ll]
+ *         }
+ *       }
+ *     }
+ *   }
+ * }
+ *
+ * After conversion:
+ * {
+ *   ScheduleBlock(root_12)
+ *   {
+ *     attrs(tile_method:TileFirstGeneralTactic)
+ *     thread_bind[blockIdx.x] for (blockIdx.x, 0, 352)
+ *     {
+ *       thread_bind[threadIdx.x] for (threadIdx.x, 0, 256)
+ *       {
+ *         ScheduleBlock(var_2)
+ *         {
+ *           i0, i1, i2, i3 = axis.bind(idx / 4096, (idx % 4096) / 256, (idx %
+ * 256) / 16, idx % 16) read_buffers(_var[i0(0:22), i2(0:16)])
+ *           write_buffers(_var_2[i0(0:22), i1(0:16), i2(0:16),i3(0:16)])
+ *           var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16]
+ *         }
+ *       }
+ *     }
+ *   }
+ * }
+ */
+void CastLonglong2Int(ir::stmt::BlockRef block);
+
 }  // namespace optim
 }  // namespace cinn
diff --git a/paddle/cinn/optim/transform_gpu_forloop.cc b/paddle/cinn/optim/transform_gpu_forloop.cc
index 5139b64dec2861..cb7be4e49d34d6 100644
--- a/paddle/cinn/optim/transform_gpu_forloop.cc
+++ b/paddle/cinn/optim/transform_gpu_forloop.cc
@@ -496,18 +496,9 @@ void OptimizeExprGPU(Expr *expr) {
 
   if (FLAGS_cinn_longlong2int) {
     ir::stmt::BlockRef block = ir::ConvertExprBlockToStmtBlock(*expr);
-    if (CanApplyLongLong2Int(block)) {
-      VLOG(10) << "Before LongLong2IntStmtPass: \n" << *expr;
-      StmtPassManager pass_manager;
-      pass_manager.AddPass(CreateLongLong2IntStmtPass());
-      pass_manager.Run(block);
-      VLOG(10) << "After LongLong2IntStmtPass: \n" << block;
-      ExprPassManager expr_pass_manager;
-      expr_pass_manager.AddPass(CreateLongLong2IntExprPass());
-      expr_pass_manager.Run(block);
-      VLOG(10) << "After LongLong2IntExprPass: \n" << block;
-      *expr = ir::ConvertStmtBlockToExprBlock(block);
-    }
+    VLOG(10) << "Before CastLonglong2Int: \n" << block;
+    CastLonglong2Int(block);
+    VLOG(10) << "After CastLonglong2Int: \n" << block;
   }
 
   VLOG(4) << "After Optimize Expr: \n" << *expr;
diff --git a/paddle/cinn/pass/pass_adaptor.h b/paddle/cinn/pass/pass_adaptor.h
index 593660254eb3ce..19275f2875f222 100644
--- a/paddle/cinn/pass/pass_adaptor.h
+++ b/paddle/cinn/pass/pass_adaptor.h
@@ -14,6 +14,7 @@
 
 #pragma once
 
+#include "paddle/cinn/ir/utils/stmt_converter.h"
 #include "paddle/cinn/pass/pass.h"
 
 namespace cinn {

From ca92196c73a2034824006e272959d60c84d3016e Mon Sep 17 00:00:00 2001
From: ZhouXin <zhou.xin@mail.ustc.edu.cn>
Date: Mon, 30 Dec 2024 04:59:18 +0000
Subject: [PATCH 7/8] Refine

---
 paddle/cinn/optim/longlong2int_pass.cc |  14 +++
 paddle/cinn/optim/longlong2int_pass.h  | 149 -------------------------
 2 files changed, 14 insertions(+), 149 deletions(-)

diff --git a/paddle/cinn/optim/longlong2int_pass.cc b/paddle/cinn/optim/longlong2int_pass.cc
index 7d03515ce0990c..e03649353c5579 100644
--- a/paddle/cinn/optim/longlong2int_pass.cc
+++ b/paddle/cinn/optim/longlong2int_pass.cc
@@ -135,6 +135,17 @@ class CastLonglong2IntMutator : public ir::IRMutator<> {
   }
 };
 
+class LongLong2IntStmtPass : public StmtPass {
+ public:
+  LongLong2IntStmtPass() : StmtPass("longlong2int_stmt") {}
+  LogicalResult Run(ir::stmt::StmtRef stmt) override;
+};
+
+class LongLong2IntExprPass : public ExprPass {
+ public:
+  LongLong2IntExprPass() : ExprPass("longlong2int_expr") {}
+  LogicalResult Run(ir::Expr expr) override;
+};
 }  // namespace
 
 LogicalResult LongLong2IntStmtPass::Run(ir::stmt::StmtRef stmt) {
@@ -235,6 +246,9 @@ std::unique_ptr<ExprPass> CreateLongLong2IntExprPass() {
   return std::make_unique<LongLong2IntExprPass>();
 }
 
+// Check if the given block can be converted from long long to int,
+// i.e., the product of the extents of all possible nested loops is within
+// INT_MAX
 bool CanApplyLongLong2Int(ir::stmt::BlockRef block) {
   CheckOverflow check_overflow;
   return !check_overflow(block);
diff --git a/paddle/cinn/optim/longlong2int_pass.h b/paddle/cinn/optim/longlong2int_pass.h
index f6c99b3df66fc9..d0c35e69d6735d 100644
--- a/paddle/cinn/optim/longlong2int_pass.h
+++ b/paddle/cinn/optim/longlong2int_pass.h
@@ -18,155 +18,6 @@
 
 namespace cinn {
 namespace optim {
-class LongLong2IntStmtPass : public StmtPass {
- public:
-  LongLong2IntStmtPass() : StmtPass("longlong2int_stmt") {}
-  LogicalResult Run(ir::stmt::StmtRef stmt) override;
-};
-
-class LongLong2IntExprPass : public ExprPass {
- public:
-  LongLong2IntExprPass() : ExprPass("longlong2int_expr") {}
-  LogicalResult Run(ir::Expr expr) override;
-};
-
-/**
- * Converts int64 (long long) types to int32 in a Stmt where possible.
- *
- * IMPORTANT: Before applying this pass, it is MANDATORY to use
- * `CanApplyLongLong2Int` to check for potential overflow issues.
- *
- * This pass is applicable in scenarios where the IR contains int64 types that
- * can be safely represented as int32 without overflow.
- *
- * When applied, this pass will convert int64 expressions to int32
- * in various constructs, including:
- * - Tensor shapes and indices
- * - Loop variables and bounds
- * - Buffer metadata (shapes, strides, offsets)
- * - Comparison operations (index only)
- *
- * Examples:
- * 1. Loop variable conversion:
- * Before conversion:
- * {
- *   ScheduleBlock(root_12)
- *   {
- *     attrs(tile_method:TileFirstGeneralTactic)
- *     thread_bind[blockIdx.x] for (blockIdx.x, 0, 352)
- *     {
- *       thread_bind[threadIdx.x] for (threadIdx.x, 0, 256)
- *       {
- *         ScheduleBlock(var_2)
- *         {
- *           i0, i1, i2, i3 = axis.bind(idx / 4096, (idx % 4096) / 256, (idx %
- * 256) / 16, idx % 16) read_buffers(_var[i0(0:22ll), i2(0:16ll)])
- *           write_buffers(_var_2[i0(0:22ll), i1(0:16ll), i2(0:16ll),
- * i3(0:16ll)])
- *         var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16ll]
- *         }
- *       }
- *     }
- *   }
- * }
- *
- * After conversion:
- * {
- *   ScheduleBlock(root_12)
- *   {
- *     attrs(tile_method:TileFirstGeneralTactic)
- *     thread_bind[blockIdx.x] for (blockIdx.x, 0, 352)
- *     {
- *       thread_bind[threadIdx.x] for (threadIdx.x, 0, 256)
- *       {
- *         ScheduleBlock(var_2)
- *         {
- *           i0, i1, i2, i3 = axis.bind(idx / 4096, (idx % 4096) / 256, (idx %
- * 256) / 16, idx % 16) read_buffers(_var[i0(0:22), i2(0:16)])
- *           write_buffers(_var_2[i0(0:22), i1(0:16), i2(0:16),i3(0:16)])
- *           var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16ll]
- *         }
- *       }
- *     }
- *   }
- * }
- *
- * The 16ll in var[i0, i2, i3 + i1 * 16ll] is not converted because it is
- * part of a Load expression, which will be converted in LongLong2IntExprPass.
- */
-std::unique_ptr<StmtPass> CreateLongLong2IntStmtPass();
-
-/**
- * Converts int64 (long long) types to int32 in an Expr where possible.
- *
- * IMPORTANT: Before applying this pass, it is MANDATORY to use
- * `CanApplyLongLong2Int` to check for potential overflow issues.
- *
- * This pass is applicable in scenarios where the IR contains int64 types that
- * can be safely represented as int32 without overflow.
- *
- * When applied, this pass will convert int64 expressions to int32
- * in various constructs, including:
- * - Tensor shapes and indices
- * - Loop variables and bounds
- * - Buffer metadata (shapes, strides, offsets)
- * - Comparison operations (index only)
- *
- * Examples:
- * 1. Loop variable conversion:
- * Before conversion:
- * {
- *   ScheduleBlock(root_12)
- *   {
- *     attrs(tile_method:TileFirstGeneralTactic)
- *     thread_bind[blockIdx.x] for (blockIdx.x, 0, 352)
- *     {
- *       thread_bind[threadIdx.x] for (threadIdx.x, 0, 256)
- *       {
- *         ScheduleBlock(var_2)
- *         {
- *           i0, i1, i2, i3 = axis.bind(idx / 4096, (idx % 4096) / 256, (idx %
- * 256) / 16, idx % 16) read_buffers(_var[i0(0:22ll), i2(0:16ll)])
- *           write_buffers(_var_2[i0(0:22ll), i1(0:16ll), i2(0:16ll),
- * i3(0:16ll)])
- *         var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16ll]
- *         }
- *       }
- *     }
- *   }
- * }
- *
- * After conversion:
- * {
- *   ScheduleBlock(root_12)
- *   {
- *     attrs(tile_method:TileFirstGeneralTactic)
- *     thread_bind[blockIdx.x] for (blockIdx.x, 0, 352)
- *     {
- *       thread_bind[threadIdx.x] for (threadIdx.x, 0, 256)
- *       {
- *         ScheduleBlock(var_2)
- *         {
- *           i0, i1, i2, i3 = axis.bind(idx / 4096ll, (idx % 4096ll) / 256ll,
- * (idx % 256ll) / 16ll, idx % 16ll) read_buffers(_var[i0(0:22ll), i2(0:16ll)])
- *           write_buffers(_var_2[i0(0:22ll), i1(0:16ll),
- * i2(0:16ll),i3(0:16ll)]) var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16]
- *         }
- *       }
- *     }
- *   }
- * }
- *
- * Only the 16ll in var[i0, i2, i3 + i1 * 16ll] is converted, because the
- * other long long Exprs are components of the ScheduleBlock and will be
- * converted in LongLong2IntStmtPass.
- */
-std::unique_ptr<ExprPass> CreateLongLong2IntExprPass();
-
-// Check if the given block can be converted from long long to int,
-// i.e., the product of the extents of all possible nested loops is within
-// INT_MAX
-bool CanApplyLongLong2Int(ir::stmt::BlockRef block);
 
 /**
  * Converts int64 (long long) types to int32 in a block where possible.

From b5d358772b930e6f69294b0cd0617468a7c064bc Mon Sep 17 00:00:00 2001
From: ZhouXin <zhou.xin@mail.ustc.edu.cn>
Date: Mon, 30 Dec 2024 06:05:34 +0000
Subject: [PATCH 8/8] Rename CastLonglong2Int to TryCastLonglong2Int for
 clarity and update references

---
 paddle/cinn/optim/longlong2int_pass.cc     | 2 +-
 paddle/cinn/optim/longlong2int_pass.h      | 2 +-
 paddle/cinn/optim/transform_gpu_forloop.cc | 3 ++-
 3 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/paddle/cinn/optim/longlong2int_pass.cc b/paddle/cinn/optim/longlong2int_pass.cc
index e03649353c5579..306b880b57c88e 100644
--- a/paddle/cinn/optim/longlong2int_pass.cc
+++ b/paddle/cinn/optim/longlong2int_pass.cc
@@ -254,7 +254,7 @@ bool CanApplyLongLong2Int(ir::stmt::BlockRef block) {
   return !check_overflow(block);
 }
 
-void CastLonglong2Int(ir::stmt::BlockRef block) {
+void TryCastLonglong2Int(ir::stmt::BlockRef block) {
   if (CanApplyLongLong2Int(block)) {
     StmtPassManager stmt_pass_manager;
     stmt_pass_manager.AddPass(CreateLongLong2IntStmtPass());
diff --git a/paddle/cinn/optim/longlong2int_pass.h b/paddle/cinn/optim/longlong2int_pass.h
index d0c35e69d6735d..fa6ba61ad8b6f3 100644
--- a/paddle/cinn/optim/longlong2int_pass.h
+++ b/paddle/cinn/optim/longlong2int_pass.h
@@ -86,7 +86,7 @@ namespace optim {
  *   }
  * }
  */
-void CastLonglong2Int(ir::stmt::BlockRef block);
+void TryCastLonglong2Int(ir::stmt::BlockRef block);
 
 }  // namespace optim
 }  // namespace cinn
diff --git a/paddle/cinn/optim/transform_gpu_forloop.cc b/paddle/cinn/optim/transform_gpu_forloop.cc
index cb7be4e49d34d6..82eac4839c48e1 100644
--- a/paddle/cinn/optim/transform_gpu_forloop.cc
+++ b/paddle/cinn/optim/transform_gpu_forloop.cc
@@ -497,8 +497,9 @@ void OptimizeExprGPU(Expr *expr) {
   if (FLAGS_cinn_longlong2int) {
     ir::stmt::BlockRef block = ir::ConvertExprBlockToStmtBlock(*expr);
     VLOG(10) << "Before CastLonglong2Int: \n" << block;
-    CastLonglong2Int(block);
+    TryCastLonglong2Int(block);
     VLOG(10) << "After CastLonglong2Int: \n" << block;
+    *expr = ir::ConvertStmtBlockToExprBlock(block);
   }
 
   VLOG(4) << "After Optimize Expr: \n" << *expr;
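
[Editor's sketch, not part of the patch series: after the final patch, the
flow in OptimizeExprGPU is to convert the Expr body to a statement block, call
TryCastLonglong2Int (which internally runs CanApplyLongLong2Int and, if it
passes, both passes), and convert back. The function name below is
illustrative; the same includes and namespace context as
transform_gpu_forloop.cc are assumed.]

    void NarrowKernelIndices(Expr* expr) {
      ir::stmt::BlockRef block = ir::ConvertExprBlockToStmtBlock(*expr);
      TryCastLonglong2Int(block);  // no-op when CanApplyLongLong2Int is false
      *expr = ir::ConvertStmtBlockToExprBlock(block);
    }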