From 0a1aac68eb132755ea27d7e066c6a3f3c1287e90 Mon Sep 17 00:00:00 2001 From: ZhouXin <zhou.xin@mail.ustc.edu.cn> Date: Wed, 25 Dec 2024 08:29:25 +0000 Subject: [PATCH 1/8] Update longlong2int pass --- paddle/cinn/optim/CMakeLists.txt | 2 +- paddle/cinn/optim/longlong2int.cc | 191 ----------------- paddle/cinn/optim/longlong2int.h | 24 --- paddle/cinn/optim/longlong2int_pass.cc | 230 +++++++++++++++++++++ paddle/cinn/optim/longlong2int_pass.h | 104 ++++++++++ paddle/cinn/optim/transform_gpu_forloop.cc | 14 +- 6 files changed, 347 insertions(+), 218 deletions(-) delete mode 100644 paddle/cinn/optim/longlong2int.cc delete mode 100644 paddle/cinn/optim/longlong2int.h create mode 100644 paddle/cinn/optim/longlong2int_pass.cc create mode 100644 paddle/cinn/optim/longlong2int_pass.h diff --git a/paddle/cinn/optim/CMakeLists.txt b/paddle/cinn/optim/CMakeLists.txt index 6d2ae9b159df89..92682e90b79240 100755 --- a/paddle/cinn/optim/CMakeLists.txt +++ b/paddle/cinn/optim/CMakeLists.txt @@ -38,7 +38,7 @@ gather_srcs( eliminate_common_global_memory_read.cc rearrange_load_instruction.cc check_tensor_buffer_map.cc - longlong2int.cc + longlong2int_pass.cc vectorize_for_trans.cc) if(WITH_CUDA OR WITH_ROCM) diff --git a/paddle/cinn/optim/longlong2int.cc b/paddle/cinn/optim/longlong2int.cc deleted file mode 100644 index de158332ac1ef9..00000000000000 --- a/paddle/cinn/optim/longlong2int.cc +++ /dev/null @@ -1,191 +0,0 @@ -// Copyright (c) 2024 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/cinn/optim/longlong2int.h" -#include "paddle/cinn/ir/ir_mutator.h" -#include "paddle/cinn/ir/ir_printer.h" -#include "paddle/cinn/ir/ir_utils.h" -#include "paddle/cinn/ir/ir_visitor.h" - -namespace cinn { -namespace optim { - -class CheckOverflow : public ir::IRVisitor { - public: - bool is_overflow(Expr* expr) { - ir::IRVisitor::Visit(expr); - return is_overflow_; - } - - private: - void Visit(const ir::For* for_op) override { - if (!for_op->extent.is_constant()) is_overflow_ = true; - if (!for_op->extent.type().is_index_type()) is_overflow_ = true; - if (curr_product_ > INT_MAX) is_overflow_ = true; - - if (is_overflow_) return; - - curr_product_ *= for_op->extent.as_int64(); - ir::IRVisitor::Visit(&for_op->body); - curr_product_ /= for_op->extent.as_int64(); - } - void Visit(const ir::ScheduleBlock* op) override { - ir::IRVisitor::Visit(&(op->body)); - } - void Visit(const ir::ScheduleBlockRealize* op) override { - ir::IRVisitor::Visit(&(op->schedule_block)); - } - void Visit(const ir::Block* op) { - for (auto& expr : op->stmts) { - ir::IRVisitor::Visit(&expr); - } - } - void Visit(const ir::IfThenElse* op) { - ir::IRVisitor::Visit(&(op->true_case)); - if (op->false_case.defined()) ir::IRVisitor::Visit(&(op->false_case)); - } - int64_t curr_product_ = 1; - bool is_overflow_ = false; -}; - -class CastLonglong2Int : public ir::IRMutator<> { - public: - void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } - - private: - void Visit(const ir::_Tensor_* op, Expr* expr) override { - auto node = expr->As<ir::_Tensor_>(); - std::for_each(node->shape.begin(), - node->shape.end(), - [&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); }); - CastBufferMeta(node->buffer); - } - void Visit(const ir::Load* op, Expr* expr) override { - auto node = expr->As<ir::Load>(); - std::for_each(node->indices.begin(), - node->indices.end(), - [&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); }); - - ir::IRMutator<>::Visit(&node->tensor, &node->tensor); - } - void Visit(const ir::Store* op, Expr* expr) override { - auto node = expr->As<ir::Store>(); - std::for_each(node->indices.begin(), - node->indices.end(), - [&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); }); - ir::IRMutator<>::Visit(&node->value, &node->value); - ir::IRMutator<>::Visit(&node->tensor, &node->tensor); - } - void Visit(const ir::IfThenElse* op, Expr* expr) override { - auto node = expr->As<ir::IfThenElse>(); - auto cond = node->condition; - if (cond.is_cmp()) { - if (cond->operand(0).is_index()) - cond->operand(0)->convert_int64_to_int32(); - if (cond->operand(1).is_index()) - cond->operand(1)->convert_int64_to_int32(); - } - ir::IRMutator<>::Visit(&node->true_case, &node->true_case); - if (node->false_case.defined()) { - ir::IRMutator<>::Visit(&node->false_case, &node->false_case); - } - } - void Visit(const ir::Select* op, Expr* expr) override { - auto node = expr->As<ir::Select>(); - auto cond = node->condition; - if (cond.is_cmp()) { - if (cond->operand(0).is_index()) - cond->operand(0)->convert_int64_to_int32(); - if (cond->operand(1).is_index()) - cond->operand(1)->convert_int64_to_int32(); - } - ir::IRMutator<>::Visit(&node->true_value, &node->true_value); - ir::IRMutator<>::Visit(&node->false_value, &node->false_value); - } - void Visit(const ir::For* op, Expr* expr) override { - auto node = expr->As<ir::For>(); - CastVarWithBound(node->loop_var); - node->min->convert_int64_to_int32(); - node->extent->convert_int64_to_int32(); - ir::IRMutator<>::Visit(&node->body, &node->body); - } - void Visit(const ir::ScheduleBlock* op, Expr* expr) override { - auto* node = expr->As<ir::ScheduleBlock>(); - - std::for_each(node->iter_vars.begin(), - node->iter_vars.end(), - [&](cinn::ir::Var& v) { CastVarWithBound(v); }); - - for (auto& buffer_range : node->read_buffers) { - if (auto range = buffer_range.As<ir::_BufferRange_>()) { - std::for_each(range->ranges.begin(), - range->ranges.end(), - [&](cinn::ir::Var& v) { CastVarWithBound(v); }); - auto bf = range->buffer.as_buffer_ref(); - CastBufferMeta(bf); - } - } - - for (auto& buffer_range : node->write_buffers) { - if (auto range = buffer_range.As<ir::_BufferRange_>()) { - std::for_each(range->ranges.begin(), - range->ranges.end(), - [&](cinn::ir::Var& v) { CastVarWithBound(v); }); - auto bf = range->buffer.as_buffer_ref(); - CastBufferMeta(bf); - } - } - ir::IRMutator<>::Visit(&(node->body), &(node->body)); - } - - void Visit(const ir::ScheduleBlockRealize* op, Expr* expr) override { - auto* node = expr->As<ir::ScheduleBlockRealize>(); - - std::for_each(node->iter_values.begin(), - node->iter_values.end(), - [&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); }); - ir::IRMutator<>::Visit(&node->schedule_block, &node->schedule_block); - } - - void CastVarWithBound(cinn::ir::Var& var) { // NOLINT - if (!var.defined()) return; - var->convert_int64_to_int32(); - auto lb = var->lower_bound; - auto ub = var->upper_bound; - if (lb.defined()) lb->convert_int64_to_int32(); - if (ub.defined()) ub->convert_int64_to_int32(); - } - void CastBufferMeta(cinn::ir::Buffer& bf) { // NOLINT - if (!bf.defined()) return; - std::for_each(bf->shape.begin(), bf->shape.end(), [&](cinn::ir::Expr& e) { - e->convert_int64_to_int32(); - }); - std::for_each(bf->strides.begin(), - bf->strides.end(), - [&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); }); - bf->elem_offset->convert_int64_to_int32(); - } -}; - -void TryCastLonglong2Int(Expr* expr) { - VLOG(6) << "Before TryCastLonglong2Int, Expr = \n" << *expr; - CheckOverflow check_overflow; - if (!check_overflow.is_overflow(expr)) { - CastLonglong2Int narrow; - narrow(expr); - } - VLOG(6) << "After TryCastLonglong2Int, Expr = \n" << *expr; -} -} // namespace optim -} // namespace cinn diff --git a/paddle/cinn/optim/longlong2int.h b/paddle/cinn/optim/longlong2int.h deleted file mode 100644 index b72e70df603a82..00000000000000 --- a/paddle/cinn/optim/longlong2int.h +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) 2024 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include "paddle/cinn/ir/ir.h" - -namespace cinn { -namespace optim { - -// Try to change the type of longlong to int in the expr. -void TryCastLonglong2Int(Expr* expr); -} // namespace optim -} // namespace cinn diff --git a/paddle/cinn/optim/longlong2int_pass.cc b/paddle/cinn/optim/longlong2int_pass.cc new file mode 100644 index 00000000000000..734878345d46c3 --- /dev/null +++ b/paddle/cinn/optim/longlong2int_pass.cc @@ -0,0 +1,230 @@ +// Copyright (c) 2024 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/optim/longlong2int_pass.h" +#include "paddle/cinn/ir/ir_mutator.h" +#include "paddle/cinn/ir/ir_printer.h" +#include "paddle/cinn/ir/ir_utils.h" +#include "paddle/cinn/ir/ir_visitor.h" +#include "paddle/cinn/ir/stmt.h" +#include "paddle/cinn/ir/stmt_visitors.h" + +namespace cinn { +namespace optim { +namespace { +using ir::stmt::BlockRef; +using ir::stmt::For; +using ir::stmt::IfThenElse; +using ir::stmt::Schedule; +using ir::stmt::StmtRef; +using ir::stmt::Store; + +class CheckOverflow : public ir::stmt::StmtVisitor<> { + public: + bool operator()(const StmtRef& stmt) { + VisitStmt(stmt); + return is_overflow_; + } + bool operator()(const BlockRef& block) { + VisitBlock(block); + return is_overflow_; + } + + private: + void VisitStmt(const StmtRef& stmt) override { + if (is_overflow_) return; + ir::stmt::StmtVisitor<>::VisitStmt(stmt); + } + + void VisitStmt(const For& for_stmt) override { + if (!for_stmt->extent().is_constant()) is_overflow_ = true; + if (!for_stmt->extent().type().is_index_type()) is_overflow_ = true; + if (curr_product_ > INT_MAX) is_overflow_ = true; + + if (is_overflow_) return; + + curr_product_ *= for_stmt->extent().as_int64(); + VisitBlock(for_stmt->body()); + curr_product_ /= for_stmt->extent().as_int64(); + } + + void VisitStmt(const Schedule& schedule_stmt) override { + VisitBlock(schedule_stmt->body()); + } + + void VisitStmt(const IfThenElse& stmt) override { + VisitBlock(stmt->true_case()); + if (stmt->false_case().defined()) { + VisitBlock(stmt->false_case()); + } + } + + void VisitStmt(const ir::stmt::Let& stmt) override { return; } + void VisitStmt(const ir::stmt::Store& stmt) override { return; } + void VisitStmt(const ir::stmt::Alloc& stmt) override { return; } + void VisitStmt(const ir::stmt::Free& stmt) override { return; } + void VisitStmt(const ir::stmt::Evaluate& stmt) override { return; } + + private: + int64_t curr_product_ = 1; + bool is_overflow_ = false; +}; + +class CastLonglong2Int : public ir::IRMutator<>, + public ir::stmt::StmtMutator<> { + public: + void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } + void operator()(StmtRef stmt) { ir::stmt::StmtMutator<>::VisitStmt(stmt); } + void operator()(BlockRef block) { + ir::stmt::StmtMutator<>::VisitBlock(block); + } + + private: + void Visit(const ir::_Tensor_* op, Expr* expr) override { + auto node = expr->As<ir::_Tensor_>(); + std::for_each(node->shape.begin(), + node->shape.end(), + [&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); }); + CastBufferMeta(node->buffer); + } + void Visit(const ir::Load* op, Expr* expr) override { + auto node = expr->As<ir::Load>(); + std::for_each(node->indices.begin(), + node->indices.end(), + [&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); }); + + ir::IRMutator<>::Visit(&node->tensor, &node->tensor); + } + void Visit(const ir::Select* op, Expr* expr) override { + auto node = expr->As<ir::Select>(); + auto cond = node->condition; + if (cond.is_cmp()) { + if (cond->operand(0).is_index()) + cond->operand(0)->convert_int64_to_int32(); + if (cond->operand(1).is_index()) + cond->operand(1)->convert_int64_to_int32(); + } + ir::IRMutator<>::Visit(&node->true_value, &node->true_value); + ir::IRMutator<>::Visit(&node->false_value, &node->false_value); + } + void VisitStmt(Store stmt) override { + std::vector<Expr> indices = stmt->indices(); + std::for_each(indices.begin(), indices.end(), [&](cinn::ir::Expr& e) { + e->convert_int64_to_int32(); + }); + Expr value = stmt->value(); + Expr tensor = stmt->tensor(); + ir::IRMutator<>::Visit(&value, &value); + ir::IRMutator<>::Visit(&tensor, &tensor); + } + void VisitStmt(IfThenElse stmt) override { + Expr cond = stmt->condition(); + if (cond.is_cmp()) { + if (cond->operand(0).is_index()) + cond->operand(0)->convert_int64_to_int32(); + if (cond->operand(1).is_index()) + cond->operand(1)->convert_int64_to_int32(); + } + ir::stmt::StmtMutator<>::VisitBlock(stmt->true_case()); + if (stmt->false_case().defined()) { + ir::stmt::StmtMutator<>::VisitBlock(stmt->false_case()); + } + } + void VisitStmt(For stmt) override { + ir::Var loop_var = stmt->loop_var(); + CastVarWithBound(loop_var); + stmt->set_loop_var(loop_var); + stmt->min()->convert_int64_to_int32(); + stmt->extent()->convert_int64_to_int32(); + ir::stmt::StmtMutator<>::VisitBlock(stmt->body()); + } + void VisitStmt(Schedule stmt) override { + std::vector<Var> iter_vars = stmt->iter_vars(); + std::for_each(iter_vars.begin(), iter_vars.end(), [&](cinn::ir::Var& v) { + CastVarWithBound(v); + }); + + for (auto& buffer_range : stmt->read_buffers()) { + if (auto range = buffer_range.As<ir::_BufferRange_>()) { + std::vector<Var> ranges = range->ranges; + std::for_each(ranges.begin(), ranges.end(), [&](cinn::ir::Var& v) { + CastVarWithBound(v); + }); + auto bf = range->buffer.as_buffer_ref(); + CastBufferMeta(bf); + } + } + + for (auto& buffer_range : stmt->write_buffers()) { + if (auto range = buffer_range.As<ir::_BufferRange_>()) { + std::vector<Var> ranges = range->ranges; + + std::for_each(ranges.begin(), ranges.end(), [&](cinn::ir::Var& v) { + CastVarWithBound(v); + }); + auto bf = range->buffer.as_buffer_ref(); + CastBufferMeta(bf); + } + } + ir::stmt::StmtMutator<>::VisitBlock(stmt->body()); + } + void VisitStmt(ir::stmt::Let stmt) override { + Expr body = stmt->body(); + ir::IRMutator<>::Visit(&body, &body); + } + void VisitStmt(ir::stmt::Evaluate stmt) override { + Expr value = stmt->value(); + ir::IRMutator<>::Visit(&value, &value); + } + + void VisitStmt(ir::stmt::Alloc stmt) override { return; } + void VisitStmt(ir::stmt::Free stmt) override { return; } + + void CastVarWithBound(cinn::ir::Var& var) { // NOLINT + if (!var.defined()) return; + var->convert_int64_to_int32(); + auto lb = var->lower_bound; + auto ub = var->upper_bound; + if (lb.defined()) lb->convert_int64_to_int32(); + if (ub.defined()) ub->convert_int64_to_int32(); + } + void CastBufferMeta(cinn::ir::Buffer& bf) { // NOLINT + if (!bf.defined()) return; + std::for_each(bf->shape.begin(), bf->shape.end(), [&](cinn::ir::Expr& e) { + e->convert_int64_to_int32(); + }); + std::for_each(bf->strides.begin(), + bf->strides.end(), + [&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); }); + bf->elem_offset->convert_int64_to_int32(); + } +}; +} // namespace + +LogicalResult LongLong2IntPass::Run(ir::stmt::StmtRef stmt) { + CastLonglong2Int narrow; + narrow(stmt); + return LogicalResult::success(); +} + +std::unique_ptr<StmtPass> CreateLongLong2IntPass() { + return std::make_unique<LongLong2IntPass>(); +} + +bool CanApplyLongLong2Int(ir::stmt::BlockRef block) { + CheckOverflow check_overflow; + return !check_overflow(block); +} +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/longlong2int_pass.h b/paddle/cinn/optim/longlong2int_pass.h new file mode 100644 index 00000000000000..3441bf8fdf434e --- /dev/null +++ b/paddle/cinn/optim/longlong2int_pass.h @@ -0,0 +1,104 @@ +// Copyright (c) 2024 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/cinn/ir/stmt.h" +#include "paddle/cinn/pass/pass.h" + +namespace cinn { +namespace optim { +class LongLong2IntPass : public StmtPass { + public: + LongLong2IntPass() : StmtPass("longlong2int") {} + LogicalResult Run(ir::stmt::StmtRef stmt) override; +}; + +/** + * Converts int64 (long long) types to int32 in a block where possible. + * + * IMPORTANT: Before applying this pass, it is MANDATORY to use + * `CanApplyLongLong2Int` to check for potential overflow issues. + * + * This pass is applicable in scenarios where the IR contains int64 types that + * can be safely represented as int32 without overflow. + * + * When applied, this pass will traverse the IR and convert int64 types to int32 + * in various constructs, including: + * - Tensor shapes and indices + * - Loop variables and bounds + * - Buffer metadata (shapes, strides, offsets) + * - Comparison operations + * + * Overflow checking: + * The pass performs overflow checking primarily for nested for-loops. This + * focus on nested loops is based on the assumption that they are the most + * common source of potential overflows in typical computational kernels. The + * check considers: + * - The product of loop extents (iteration counts) + * - Whether loop bounds are constant and of index type + * + * + * Examples: + * 1. Loop variable conversion: + * Before conversion: + * { + * ScheduleBlock(root_12) + * { + * attrs(tile_method:TileFirstGeneralTactic) + * thread_bind[blockIdx.x] for (blockIdx.x, 0, 352) + * { + * thread_bind[threadIdx.x] for (threadIdx.x, 0, 256) + * { + * ScheduleBlock(var_2) + * { + * i0, i1, i2, i3 = axis.bind(idx / 4096, (idx % 4096) / 256, (idx % + * 256) / 16, idx % 16) read_buffers(_var[i0(0:22ll), i2(0:16ll)]) + * write_buffers(_var_2[i0(0:22ll), i1(0:16ll), i2(0:16ll), + * i3(0:16ll)]) + * var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16ll] + * } + * } + * } + * } + * } + * + * After conversion: + * { + * ScheduleBlock(root_12) + * { + * attrs(tile_method:TileFirstGeneralTactic) + * thread_bind[blockIdx.x] for (blockIdx.x, 0, 352) + * { + * thread_bind[threadIdx.x] for (threadIdx.x, 0, 256) + * { + * ScheduleBlock(var_2) + * { + * i0, i1, i2, i3 = axis.bind(idx / 4096, (idx % 4096) / 256, (idx % + * 256) / 16, idx % 16) read_buffers(_var[i0(0:22), i2(0:16)]) + * write_buffers(_var_2[i0(0:22), i1(0:16), i2(0:16),i3(0:16)]) + * var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16] + * } + * } + * } + * } + * } + */ +std::unique_ptr<StmtPass> CreateLongLong2IntPass(); + +// Check if the given block can be converted from long long to int, +// A.K.A. the product of the extents of all possible nested loops is within +// INT_MAX +bool CanApplyLongLong2Int(ir::stmt::BlockRef block); +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/transform_gpu_forloop.cc b/paddle/cinn/optim/transform_gpu_forloop.cc index 10610ed0fd0361..c949d201143c02 100644 --- a/paddle/cinn/optim/transform_gpu_forloop.cc +++ b/paddle/cinn/optim/transform_gpu_forloop.cc @@ -28,12 +28,14 @@ #include "paddle/cinn/ir/ir_mutator.h" #include "paddle/cinn/ir/ir_printer.h" #include "paddle/cinn/ir/utils/ir_copy.h" +#include "paddle/cinn/ir/utils/stmt_converter.h" #include "paddle/cinn/optim/eliminate_common_factor_of_local_index.h" #include "paddle/cinn/optim/ir_simplify.h" -#include "paddle/cinn/optim/longlong2int.h" +#include "paddle/cinn/optim/longlong2int_pass.h" #include "paddle/cinn/optim/replace_var_with_expr.h" #include "paddle/cinn/optim/resize_buffer.h" #include "paddle/cinn/optim/update_buffer_axis_pass.h" +#include "paddle/cinn/pass/pass_manager.h" #include "paddle/cinn/poly/isl_utils.h" #include "paddle/cinn/poly/stage.h" #include "paddle/cinn/runtime/intrinsic.h" @@ -493,7 +495,15 @@ void OptimizeExprGPU(Expr *expr) { ResizeBufferToMaxVarRange(expr); if (FLAGS_cinn_longlong2int) { - TryCastLonglong2Int(expr); + ir::stmt::BlockRef block = ir::ConvertExprBlockToStmtBlock(*expr); + if (CanApplyLongLong2Int(block)) { + VLOG(10) << "Before LongLong2IntPass: \n" << *expr; + StmtPassManager pass_manager; + pass_manager.AddPass(CreateLongLong2IntPass()); + pass_manager.Run(block); + *expr = ir::ConvertStmtBlockToExprBlock(block); + VLOG(10) << "After LongLong2IntPass: \n" << *expr; + } } VLOG(4) << "After Optimize Expr: \n" << *expr; From d2cbcf6e5b4ec343ba7340169959bc66040082a1 Mon Sep 17 00:00:00 2001 From: ZhouXin <zhou.xin@mail.ustc.edu.cn> Date: Fri, 27 Dec 2024 04:18:25 +0000 Subject: [PATCH 2/8] Split ll2int to tow passes --- paddle/cinn/optim/longlong2int_pass.cc | 43 +++++----- paddle/cinn/optim/longlong2int_pass.h | 99 ++++++++++++++++++---- paddle/cinn/optim/transform_gpu_forloop.cc | 10 ++- 3 files changed, 111 insertions(+), 41 deletions(-) diff --git a/paddle/cinn/optim/longlong2int_pass.cc b/paddle/cinn/optim/longlong2int_pass.cc index 734878345d46c3..c9d3291be0cff3 100644 --- a/paddle/cinn/optim/longlong2int_pass.cc +++ b/paddle/cinn/optim/longlong2int_pass.cc @@ -103,9 +103,9 @@ class CastLonglong2Int : public ir::IRMutator<>, std::for_each(node->indices.begin(), node->indices.end(), [&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); }); - ir::IRMutator<>::Visit(&node->tensor, &node->tensor); } + void Visit(const ir::Select* op, Expr* expr) override { auto node = expr->As<ir::Select>(); auto cond = node->condition; @@ -123,10 +123,6 @@ class CastLonglong2Int : public ir::IRMutator<>, std::for_each(indices.begin(), indices.end(), [&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); }); - Expr value = stmt->value(); - Expr tensor = stmt->tensor(); - ir::IRMutator<>::Visit(&value, &value); - ir::IRMutator<>::Visit(&tensor, &tensor); } void VisitStmt(IfThenElse stmt) override { Expr cond = stmt->condition(); @@ -136,18 +132,12 @@ class CastLonglong2Int : public ir::IRMutator<>, if (cond->operand(1).is_index()) cond->operand(1)->convert_int64_to_int32(); } - ir::stmt::StmtMutator<>::VisitBlock(stmt->true_case()); - if (stmt->false_case().defined()) { - ir::stmt::StmtMutator<>::VisitBlock(stmt->false_case()); - } } void VisitStmt(For stmt) override { ir::Var loop_var = stmt->loop_var(); CastVarWithBound(loop_var); - stmt->set_loop_var(loop_var); stmt->min()->convert_int64_to_int32(); stmt->extent()->convert_int64_to_int32(); - ir::stmt::StmtMutator<>::VisitBlock(stmt->body()); } void VisitStmt(Schedule stmt) override { std::vector<Var> iter_vars = stmt->iter_vars(); @@ -155,6 +145,11 @@ class CastLonglong2Int : public ir::IRMutator<>, CastVarWithBound(v); }); + std::vector<Expr> iter_values = stmt->iter_values(); + std::for_each(iter_values.begin(), + iter_values.end(), + [&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); }); + for (auto& buffer_range : stmt->read_buffers()) { if (auto range = buffer_range.As<ir::_BufferRange_>()) { std::vector<Var> ranges = range->ranges; @@ -179,14 +174,8 @@ class CastLonglong2Int : public ir::IRMutator<>, } ir::stmt::StmtMutator<>::VisitBlock(stmt->body()); } - void VisitStmt(ir::stmt::Let stmt) override { - Expr body = stmt->body(); - ir::IRMutator<>::Visit(&body, &body); - } - void VisitStmt(ir::stmt::Evaluate stmt) override { - Expr value = stmt->value(); - ir::IRMutator<>::Visit(&value, &value); - } + void VisitStmt(ir::stmt::Let stmt) override { return; } + void VisitStmt(ir::stmt::Evaluate stmt) override { return; } void VisitStmt(ir::stmt::Alloc stmt) override { return; } void VisitStmt(ir::stmt::Free stmt) override { return; } @@ -212,19 +201,29 @@ class CastLonglong2Int : public ir::IRMutator<>, }; } // namespace -LogicalResult LongLong2IntPass::Run(ir::stmt::StmtRef stmt) { +LogicalResult LongLong2IntStmtPass::Run(ir::stmt::StmtRef stmt) { CastLonglong2Int narrow; narrow(stmt); return LogicalResult::success(); } -std::unique_ptr<StmtPass> CreateLongLong2IntPass() { - return std::make_unique<LongLong2IntPass>(); +LogicalResult LongLong2IntExprPass::Run(ir::Expr expr) { + CastLonglong2Int narrow; + narrow(&expr); + return LogicalResult::success(); +} +std::unique_ptr<StmtPass> CreateLongLong2IntStmtPass() { + return std::make_unique<LongLong2IntStmtPass>(); +} + +std::unique_ptr<ExprPass> CreateLongLong2IntExprPass() { + return std::make_unique<LongLong2IntExprPass>(); } bool CanApplyLongLong2Int(ir::stmt::BlockRef block) { CheckOverflow check_overflow; return !check_overflow(block); } + } // namespace optim } // namespace cinn diff --git a/paddle/cinn/optim/longlong2int_pass.h b/paddle/cinn/optim/longlong2int_pass.h index 3441bf8fdf434e..d912ebea8ab51a 100644 --- a/paddle/cinn/optim/longlong2int_pass.h +++ b/paddle/cinn/optim/longlong2int_pass.h @@ -18,14 +18,20 @@ namespace cinn { namespace optim { -class LongLong2IntPass : public StmtPass { +class LongLong2IntStmtPass : public StmtPass { public: - LongLong2IntPass() : StmtPass("longlong2int") {} + LongLong2IntStmtPass() : StmtPass("longlong2int_stmt") {} LogicalResult Run(ir::stmt::StmtRef stmt) override; }; +class LongLong2IntExprPass : public ExprPass { + public: + LongLong2IntExprPass() : ExprPass("longlong2int_expr") {} + LogicalResult Run(ir::Expr expr) override; +}; + /** - * Converts int64 (long long) types to int32 in a block where possible. + * Converts int64 (long long) types to int32 in a Stmt where possible. * * IMPORTANT: Before applying this pass, it is MANDATORY to use * `CanApplyLongLong2Int` to check for potential overflow issues. @@ -33,21 +39,12 @@ class LongLong2IntPass : public StmtPass { * This pass is applicable in scenarios where the IR contains int64 types that * can be safely represented as int32 without overflow. * - * When applied, this pass will traverse the IR and convert int64 types to int32 + * When applied, this pass will convert int64 expression to int32 * in various constructs, including: * - Tensor shapes and indices * - Loop variables and bounds * - Buffer metadata (shapes, strides, offsets) - * - Comparison operations - * - * Overflow checking: - * The pass performs overflow checking primarily for nested for-loops. This - * focus on nested loops is based on the assumption that they are the most - * common source of potential overflows in typical computational kernels. The - * check considers: - * - The product of loop extents (iteration counts) - * - Whether loop bounds are constant and of index type - * + * - Comparison operations (index only) * * Examples: * 1. Loop variable conversion: @@ -87,14 +84,84 @@ class LongLong2IntPass : public StmtPass { * i0, i1, i2, i3 = axis.bind(idx / 4096, (idx % 4096) / 256, (idx % * 256) / 16, idx % 16) read_buffers(_var[i0(0:22), i2(0:16)]) * write_buffers(_var_2[i0(0:22), i1(0:16), i2(0:16),i3(0:16)]) - * var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16] + * var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16ll] + * } + * } + * } + * } + * } + * + * The 16ll in var[i0, i2, i3 + i1 * 16ll] is not converted for it is part of + * Load Exoression, which will be converted in LongLong2IntExprPass. + */ +std::unique_ptr<StmtPass> CreateLongLong2IntStmtPass(); + +/** + * Converts int64 (long long) types to int32 in a Stmt where possible. + * + * IMPORTANT: Before applying this pass, it is MANDATORY to use + * `CanApplyLongLong2Int` to check for potential overflow issues. + * + * This pass is applicable in scenarios where the IR contains int64 types that + * can be safely represented as int32 without overflow. + * + * When applied, this pass will convert int64 expression to int32 + * in various constructs, including: + * - Tensor shapes and indices + * - Loop variables and bounds + * - Buffer metadata (shapes, strides, offsets) + * - Comparison operations (index only) + * + * Examples: + * 1. Loop variable conversion: + * Before conversion: + * { + * ScheduleBlock(root_12) + * { + * attrs(tile_method:TileFirstGeneralTactic) + * thread_bind[blockIdx.x] for (blockIdx.x, 0, 352) + * { + * thread_bind[threadIdx.x] for (threadIdx.x, 0, 256) + * { + * ScheduleBlock(var_2) + * { + * i0, i1, i2, i3 = axis.bind(idx / 4096, (idx % 4096) / 256, (idx % + * 256) / 16, idx % 16) read_buffers(_var[i0(0:22ll), i2(0:16ll)]) + * write_buffers(_var_2[i0(0:22ll), i1(0:16ll), i2(0:16ll), + * i3(0:16ll)]) + * var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16ll] * } * } * } * } * } + * + * After conversion: + * { + * ScheduleBlock(root_12) + * { + * attrs(tile_method:TileFirstGeneralTactic) + * thread_bind[blockIdx.x] for (blockIdx.x, 0, 352) + * { + * thread_bind[threadIdx.x] for (threadIdx.x, 0, 256) + * { + * ScheduleBlock(var_2) + * { + * i0, i1, i2, i3 = axis.bind(idx / 4096ll, (idx % 4096ll) / 256ll, + * (idx % 256ll) / 16ll, idx % 16ll) read_buffers(_var[i0(0:22ll), i2(0:16ll)]) + * write_buffers(_var_2[i0(0:22ll), i1(0:16ll), + * i2(0:16ll),i3(0:16ll)]) var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16] + * } + * } + * } + * } + * } + * + * Only 16ll in var[i0, i2, i3 + i1 * 16ll] is converted for other longlong + * Exprs are components of ScheduleBlock, which will be converted in + * LongLong2IntStmtPass. */ -std::unique_ptr<StmtPass> CreateLongLong2IntPass(); +std::unique_ptr<ExprPass> CreateLongLong2IntExprPass(); // Check if the given block can be converted from long long to int, // A.K.A. the product of the extents of all possible nested loops is within diff --git a/paddle/cinn/optim/transform_gpu_forloop.cc b/paddle/cinn/optim/transform_gpu_forloop.cc index c949d201143c02..5139b64dec2861 100644 --- a/paddle/cinn/optim/transform_gpu_forloop.cc +++ b/paddle/cinn/optim/transform_gpu_forloop.cc @@ -497,12 +497,16 @@ void OptimizeExprGPU(Expr *expr) { if (FLAGS_cinn_longlong2int) { ir::stmt::BlockRef block = ir::ConvertExprBlockToStmtBlock(*expr); if (CanApplyLongLong2Int(block)) { - VLOG(10) << "Before LongLong2IntPass: \n" << *expr; + VLOG(10) << "Before LongLong2IntStmtPass: \n" << *expr; StmtPassManager pass_manager; - pass_manager.AddPass(CreateLongLong2IntPass()); + pass_manager.AddPass(CreateLongLong2IntStmtPass()); pass_manager.Run(block); + VLOG(10) << "After LongLong2IntStmtPass: \n" << block; + ExprPassManager expr_pass_manager; + expr_pass_manager.AddPass(CreateLongLong2IntExprPass()); + expr_pass_manager.Run(block); + VLOG(10) << "After LongLong2IntExprPass: \n" << block; *expr = ir::ConvertStmtBlockToExprBlock(block); - VLOG(10) << "After LongLong2IntPass: \n" << *expr; } } From 73e89823e709cbaf2b833b83b58e05d391143119 Mon Sep 17 00:00:00 2001 From: ZhouXin <zhou.xin@mail.ustc.edu.cn> Date: Wed, 25 Dec 2024 06:49:45 +0000 Subject: [PATCH 3/8] apply cherry pick From 95def9f5470690b7283b9b38b0e0eac07f088ac4 Mon Sep 17 00:00:00 2001 From: ZhouXin <zhou.xin@mail.ustc.edu.cn> Date: Fri, 27 Dec 2024 07:28:53 +0000 Subject: [PATCH 4/8] Extract stmt logic from mutator into StmtPass --- paddle/cinn/optim/longlong2int_pass.cc | 128 ++++++++++++++----------- 1 file changed, 71 insertions(+), 57 deletions(-) diff --git a/paddle/cinn/optim/longlong2int_pass.cc b/paddle/cinn/optim/longlong2int_pass.cc index c9d3291be0cff3..08b2f199943f8a 100644 --- a/paddle/cinn/optim/longlong2int_pass.cc +++ b/paddle/cinn/optim/longlong2int_pass.cc @@ -30,6 +30,25 @@ using ir::stmt::Schedule; using ir::stmt::StmtRef; using ir::stmt::Store; +void CastVarWithBound(cinn::ir::Var& var) { // NOLINT + if (!var.defined()) return; + var->convert_int64_to_int32(); + auto lb = var->lower_bound; + auto ub = var->upper_bound; + if (lb.defined()) lb->convert_int64_to_int32(); + if (ub.defined()) ub->convert_int64_to_int32(); +} +void CastBufferMeta(cinn::ir::Buffer& bf) { // NOLINT + if (!bf.defined()) return; + std::for_each(bf->shape.begin(), bf->shape.end(), [&](cinn::ir::Expr& e) { + e->convert_int64_to_int32(); + }); + std::for_each(bf->strides.begin(), bf->strides.end(), [&](cinn::ir::Expr& e) { + e->convert_int64_to_int32(); + }); + bf->elem_offset->convert_int64_to_int32(); +} + class CheckOverflow : public ir::stmt::StmtVisitor<> { public: bool operator()(const StmtRef& stmt) { @@ -81,14 +100,9 @@ class CheckOverflow : public ir::stmt::StmtVisitor<> { bool is_overflow_ = false; }; -class CastLonglong2Int : public ir::IRMutator<>, - public ir::stmt::StmtMutator<> { +class CastLonglong2IntMutator : public ir::IRMutator<> { public: void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } - void operator()(StmtRef stmt) { ir::stmt::StmtMutator<>::VisitStmt(stmt); } - void operator()(BlockRef block) { - ir::stmt::StmtMutator<>::VisitBlock(block); - } private: void Visit(const ir::_Tensor_* op, Expr* expr) override { @@ -118,39 +132,50 @@ class CastLonglong2Int : public ir::IRMutator<>, ir::IRMutator<>::Visit(&node->true_value, &node->true_value); ir::IRMutator<>::Visit(&node->false_value, &node->false_value); } - void VisitStmt(Store stmt) override { - std::vector<Expr> indices = stmt->indices(); - std::for_each(indices.begin(), indices.end(), [&](cinn::ir::Expr& e) { - e->convert_int64_to_int32(); - }); - } - void VisitStmt(IfThenElse stmt) override { - Expr cond = stmt->condition(); +}; + +} // namespace + +LogicalResult LongLong2IntStmtPass::Run(ir::stmt::StmtRef stmt) { + auto CastStore = [](StmtRef stmt) { + Store store_stmt = stmt.as<Store>(); + for (Expr index : store_stmt->indices()) { + index->convert_int64_to_int32(); + } + }; + + auto CastIfThenElse = [](StmtRef stmt) { + IfThenElse if_stmt = stmt.as<IfThenElse>(); + Expr cond = if_stmt->condition(); if (cond.is_cmp()) { if (cond->operand(0).is_index()) cond->operand(0)->convert_int64_to_int32(); if (cond->operand(1).is_index()) cond->operand(1)->convert_int64_to_int32(); } - } - void VisitStmt(For stmt) override { - ir::Var loop_var = stmt->loop_var(); + }; + + auto CastFor = [](StmtRef stmt) { + For for_stmt = stmt.as<For>(); + ir::Var loop_var = for_stmt->loop_var(); CastVarWithBound(loop_var); - stmt->min()->convert_int64_to_int32(); - stmt->extent()->convert_int64_to_int32(); - } - void VisitStmt(Schedule stmt) override { - std::vector<Var> iter_vars = stmt->iter_vars(); + for_stmt->min()->convert_int64_to_int32(); + for_stmt->extent()->convert_int64_to_int32(); + }; + + auto CastSchedule = [](StmtRef stmt) { + Schedule schedule_stmt = stmt.as<Schedule>(); + std::vector<Var> iter_vars = schedule_stmt->iter_vars(); std::for_each(iter_vars.begin(), iter_vars.end(), [&](cinn::ir::Var& v) { CastVarWithBound(v); }); - std::vector<Expr> iter_values = stmt->iter_values(); + std::vector<Expr> iter_values = schedule_stmt->iter_values(); std::for_each(iter_values.begin(), iter_values.end(), [&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); }); - for (auto& buffer_range : stmt->read_buffers()) { + for (auto& buffer_range : schedule_stmt->read_buffers()) { if (auto range = buffer_range.As<ir::_BufferRange_>()) { std::vector<Var> ranges = range->ranges; std::for_each(ranges.begin(), ranges.end(), [&](cinn::ir::Var& v) { @@ -161,7 +186,7 @@ class CastLonglong2Int : public ir::IRMutator<>, } } - for (auto& buffer_range : stmt->write_buffers()) { + for (auto& buffer_range : schedule_stmt->write_buffers()) { if (auto range = buffer_range.As<ir::_BufferRange_>()) { std::vector<Var> ranges = range->ranges; @@ -172,43 +197,32 @@ class CastLonglong2Int : public ir::IRMutator<>, CastBufferMeta(bf); } } - ir::stmt::StmtMutator<>::VisitBlock(stmt->body()); - } - void VisitStmt(ir::stmt::Let stmt) override { return; } - void VisitStmt(ir::stmt::Evaluate stmt) override { return; } - - void VisitStmt(ir::stmt::Alloc stmt) override { return; } - void VisitStmt(ir::stmt::Free stmt) override { return; } - - void CastVarWithBound(cinn::ir::Var& var) { // NOLINT - if (!var.defined()) return; - var->convert_int64_to_int32(); - auto lb = var->lower_bound; - auto ub = var->upper_bound; - if (lb.defined()) lb->convert_int64_to_int32(); - if (ub.defined()) ub->convert_int64_to_int32(); - } - void CastBufferMeta(cinn::ir::Buffer& bf) { // NOLINT - if (!bf.defined()) return; - std::for_each(bf->shape.begin(), bf->shape.end(), [&](cinn::ir::Expr& e) { - e->convert_int64_to_int32(); - }); - std::for_each(bf->strides.begin(), - bf->strides.end(), - [&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); }); - bf->elem_offset->convert_int64_to_int32(); - } -}; -} // namespace + }; -LogicalResult LongLong2IntStmtPass::Run(ir::stmt::StmtRef stmt) { - CastLonglong2Int narrow; - narrow(stmt); + switch (stmt->stmt_type()) { + case ir::StmtNodeTy::Store: + CastStore(stmt); + break; + + case ir::StmtNodeTy::IfThenElse: + CastIfThenElse(stmt); + break; + + case ir::StmtNodeTy::For: + CastFor(stmt); + break; + + case ir::StmtNodeTy::Schedule: + CastSchedule(stmt); + break; + default: + break; + } return LogicalResult::success(); } LogicalResult LongLong2IntExprPass::Run(ir::Expr expr) { - CastLonglong2Int narrow; + CastLonglong2IntMutator narrow; narrow(&expr); return LogicalResult::success(); } From baf7ccb7105e5d8eca7b838b0365a2343234e866 Mon Sep 17 00:00:00 2001 From: ZhouXin <zhou.xin@mail.ustc.edu.cn> Date: Fri, 27 Dec 2024 07:34:16 +0000 Subject: [PATCH 5/8] Refine comment --- paddle/cinn/optim/longlong2int_pass.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/cinn/optim/longlong2int_pass.h b/paddle/cinn/optim/longlong2int_pass.h index d912ebea8ab51a..ddb9f14f453d82 100644 --- a/paddle/cinn/optim/longlong2int_pass.h +++ b/paddle/cinn/optim/longlong2int_pass.h @@ -97,7 +97,7 @@ class LongLong2IntExprPass : public ExprPass { std::unique_ptr<StmtPass> CreateLongLong2IntStmtPass(); /** - * Converts int64 (long long) types to int32 in a Stmt where possible. + * Converts int64 (long long) types to int32 in a Expr where possible. * * IMPORTANT: Before applying this pass, it is MANDATORY to use * `CanApplyLongLong2Int` to check for potential overflow issues. From 82ed67384253d385d5deee28d03ae48cef7bf335 Mon Sep 17 00:00:00 2001 From: ZhouXin <zhou.xin@mail.ustc.edu.cn> Date: Mon, 30 Dec 2024 03:23:30 +0000 Subject: [PATCH 6/8] Implement CastLonglong2Int function to convert int64 types to int32 with overflow checks --- paddle/cinn/optim/longlong2int_pass.cc | 13 ++++ paddle/cinn/optim/longlong2int_pass.h | 70 ++++++++++++++++++++++ paddle/cinn/optim/transform_gpu_forloop.cc | 15 +---- paddle/cinn/pass/pass_adaptor.h | 1 + 4 files changed, 87 insertions(+), 12 deletions(-) diff --git a/paddle/cinn/optim/longlong2int_pass.cc b/paddle/cinn/optim/longlong2int_pass.cc index 08b2f199943f8a..7d03515ce0990c 100644 --- a/paddle/cinn/optim/longlong2int_pass.cc +++ b/paddle/cinn/optim/longlong2int_pass.cc @@ -19,6 +19,7 @@ #include "paddle/cinn/ir/ir_visitor.h" #include "paddle/cinn/ir/stmt.h" #include "paddle/cinn/ir/stmt_visitors.h" +#include "paddle/cinn/pass/pass_manager.h" namespace cinn { namespace optim { @@ -239,5 +240,17 @@ bool CanApplyLongLong2Int(ir::stmt::BlockRef block) { return !check_overflow(block); } +void CastLonglong2Int(ir::stmt::BlockRef block) { + if (CanApplyLongLong2Int(block)) { + StmtPassManager stmt_pass_manager; + stmt_pass_manager.AddPass(CreateLongLong2IntStmtPass()); + ExprPassManager expr_pass_manager; + expr_pass_manager.AddPass(CreateLongLong2IntExprPass()); + + stmt_pass_manager.Run(block); + expr_pass_manager.Run(block); + } +} + } // namespace optim } // namespace cinn diff --git a/paddle/cinn/optim/longlong2int_pass.h b/paddle/cinn/optim/longlong2int_pass.h index ddb9f14f453d82..f6c99b3df66fc9 100644 --- a/paddle/cinn/optim/longlong2int_pass.h +++ b/paddle/cinn/optim/longlong2int_pass.h @@ -167,5 +167,75 @@ std::unique_ptr<ExprPass> CreateLongLong2IntExprPass(); // A.K.A. the product of the extents of all possible nested loops is within // INT_MAX bool CanApplyLongLong2Int(ir::stmt::BlockRef block); + +/** + * Converts int64 (long long) types to int32 in a block where possible. + * + * This pass is applicable in scenarios where the IR contains int64 types that + * can be safely represented as int32 without overflow. + * + * When applied, this pass will traverse the IR and convert int64 types to int32 + * in various constructs, including: + * - Tensor shapes and indices + * - Loop variables and bounds + * - Buffer metadata (shapes, strides, offsets) + * - Comparison operations + * + * Overflow checking: + * The pass performs overflow checking primarily for nested for-loops. This + * focus on nested loops is based on the assumption that they are the most + * common source of potential overflows in typical computational kernels. The + * check considers: + * - The product of loop extents (iteration counts) + * - Whether loop bounds are constant and of index type + * + * + * Examples: + * 1. Loop variable conversion: + * Before conversion: + * { + * ScheduleBlock(root_12) + * { + * attrs(tile_method:TileFirstGeneralTactic) + * thread_bind[blockIdx.x] for (blockIdx.x, 0, 352) + * { + * thread_bind[threadIdx.x] for (threadIdx.x, 0, 256) + * { + * ScheduleBlock(var_2) + * { + * i0, i1, i2, i3 = axis.bind(idx / 4096, (idx % 4096) / 256, (idx % + * 256) / 16, idx % 16) read_buffers(_var[i0(0:22ll), i2(0:16ll)]) + * write_buffers(_var_2[i0(0:22ll), i1(0:16ll), i2(0:16ll), + * i3(0:16ll)]) + * var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16ll] + * } + * } + * } + * } + * } + * + * After conversion: + * { + * ScheduleBlock(root_12) + * { + * attrs(tile_method:TileFirstGeneralTactic) + * thread_bind[blockIdx.x] for (blockIdx.x, 0, 352) + * { + * thread_bind[threadIdx.x] for (threadIdx.x, 0, 256) + * { + * ScheduleBlock(var_2) + * { + * i0, i1, i2, i3 = axis.bind(idx / 4096, (idx % 4096) / 256, (idx % + * 256) / 16, idx % 16) read_buffers(_var[i0(0:22), i2(0:16)]) + * write_buffers(_var_2[i0(0:22), i1(0:16), i2(0:16),i3(0:16)]) + * var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16] + * } + * } + * } + * } + * } + */ +void CastLonglong2Int(ir::stmt::BlockRef block); + } // namespace optim } // namespace cinn diff --git a/paddle/cinn/optim/transform_gpu_forloop.cc b/paddle/cinn/optim/transform_gpu_forloop.cc index 5139b64dec2861..cb7be4e49d34d6 100644 --- a/paddle/cinn/optim/transform_gpu_forloop.cc +++ b/paddle/cinn/optim/transform_gpu_forloop.cc @@ -496,18 +496,9 @@ void OptimizeExprGPU(Expr *expr) { if (FLAGS_cinn_longlong2int) { ir::stmt::BlockRef block = ir::ConvertExprBlockToStmtBlock(*expr); - if (CanApplyLongLong2Int(block)) { - VLOG(10) << "Before LongLong2IntStmtPass: \n" << *expr; - StmtPassManager pass_manager; - pass_manager.AddPass(CreateLongLong2IntStmtPass()); - pass_manager.Run(block); - VLOG(10) << "After LongLong2IntStmtPass: \n" << block; - ExprPassManager expr_pass_manager; - expr_pass_manager.AddPass(CreateLongLong2IntExprPass()); - expr_pass_manager.Run(block); - VLOG(10) << "After LongLong2IntExprPass: \n" << block; - *expr = ir::ConvertStmtBlockToExprBlock(block); - } + VLOG(10) << "Before CastLonglong2Int: \n" << block; + CastLonglong2Int(block); + VLOG(10) << "After CastLonglong2Int: \n" << block; } VLOG(4) << "After Optimize Expr: \n" << *expr; diff --git a/paddle/cinn/pass/pass_adaptor.h b/paddle/cinn/pass/pass_adaptor.h index 593660254eb3ce..19275f2875f222 100644 --- a/paddle/cinn/pass/pass_adaptor.h +++ b/paddle/cinn/pass/pass_adaptor.h @@ -14,6 +14,7 @@ #pragma once +#include "paddle/cinn/ir/utils/stmt_converter.h" #include "paddle/cinn/pass/pass.h" namespace cinn { From ca92196c73a2034824006e272959d60c84d3016e Mon Sep 17 00:00:00 2001 From: ZhouXin <zhou.xin@mail.ustc.edu.cn> Date: Mon, 30 Dec 2024 04:59:18 +0000 Subject: [PATCH 7/8] Refine --- paddle/cinn/optim/longlong2int_pass.cc | 14 +++ paddle/cinn/optim/longlong2int_pass.h | 149 ------------------------- 2 files changed, 14 insertions(+), 149 deletions(-) diff --git a/paddle/cinn/optim/longlong2int_pass.cc b/paddle/cinn/optim/longlong2int_pass.cc index 7d03515ce0990c..e03649353c5579 100644 --- a/paddle/cinn/optim/longlong2int_pass.cc +++ b/paddle/cinn/optim/longlong2int_pass.cc @@ -135,6 +135,17 @@ class CastLonglong2IntMutator : public ir::IRMutator<> { } }; +class LongLong2IntStmtPass : public StmtPass { + public: + LongLong2IntStmtPass() : StmtPass("longlong2int_stmt") {} + LogicalResult Run(ir::stmt::StmtRef stmt) override; +}; + +class LongLong2IntExprPass : public ExprPass { + public: + LongLong2IntExprPass() : ExprPass("longlong2int_expr") {} + LogicalResult Run(ir::Expr expr) override; +}; } // namespace LogicalResult LongLong2IntStmtPass::Run(ir::stmt::StmtRef stmt) { @@ -235,6 +246,9 @@ std::unique_ptr<ExprPass> CreateLongLong2IntExprPass() { return std::make_unique<LongLong2IntExprPass>(); } +// Check if the given block can be converted from long long to int, +// A.K.A. the product of the extents of all possible nested loops is within +// INT_MAX bool CanApplyLongLong2Int(ir::stmt::BlockRef block) { CheckOverflow check_overflow; return !check_overflow(block); diff --git a/paddle/cinn/optim/longlong2int_pass.h b/paddle/cinn/optim/longlong2int_pass.h index f6c99b3df66fc9..d0c35e69d6735d 100644 --- a/paddle/cinn/optim/longlong2int_pass.h +++ b/paddle/cinn/optim/longlong2int_pass.h @@ -18,155 +18,6 @@ namespace cinn { namespace optim { -class LongLong2IntStmtPass : public StmtPass { - public: - LongLong2IntStmtPass() : StmtPass("longlong2int_stmt") {} - LogicalResult Run(ir::stmt::StmtRef stmt) override; -}; - -class LongLong2IntExprPass : public ExprPass { - public: - LongLong2IntExprPass() : ExprPass("longlong2int_expr") {} - LogicalResult Run(ir::Expr expr) override; -}; - -/** - * Converts int64 (long long) types to int32 in a Stmt where possible. - * - * IMPORTANT: Before applying this pass, it is MANDATORY to use - * `CanApplyLongLong2Int` to check for potential overflow issues. - * - * This pass is applicable in scenarios where the IR contains int64 types that - * can be safely represented as int32 without overflow. - * - * When applied, this pass will convert int64 expression to int32 - * in various constructs, including: - * - Tensor shapes and indices - * - Loop variables and bounds - * - Buffer metadata (shapes, strides, offsets) - * - Comparison operations (index only) - * - * Examples: - * 1. Loop variable conversion: - * Before conversion: - * { - * ScheduleBlock(root_12) - * { - * attrs(tile_method:TileFirstGeneralTactic) - * thread_bind[blockIdx.x] for (blockIdx.x, 0, 352) - * { - * thread_bind[threadIdx.x] for (threadIdx.x, 0, 256) - * { - * ScheduleBlock(var_2) - * { - * i0, i1, i2, i3 = axis.bind(idx / 4096, (idx % 4096) / 256, (idx % - * 256) / 16, idx % 16) read_buffers(_var[i0(0:22ll), i2(0:16ll)]) - * write_buffers(_var_2[i0(0:22ll), i1(0:16ll), i2(0:16ll), - * i3(0:16ll)]) - * var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16ll] - * } - * } - * } - * } - * } - * - * After conversion: - * { - * ScheduleBlock(root_12) - * { - * attrs(tile_method:TileFirstGeneralTactic) - * thread_bind[blockIdx.x] for (blockIdx.x, 0, 352) - * { - * thread_bind[threadIdx.x] for (threadIdx.x, 0, 256) - * { - * ScheduleBlock(var_2) - * { - * i0, i1, i2, i3 = axis.bind(idx / 4096, (idx % 4096) / 256, (idx % - * 256) / 16, idx % 16) read_buffers(_var[i0(0:22), i2(0:16)]) - * write_buffers(_var_2[i0(0:22), i1(0:16), i2(0:16),i3(0:16)]) - * var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16ll] - * } - * } - * } - * } - * } - * - * The 16ll in var[i0, i2, i3 + i1 * 16ll] is not converted for it is part of - * Load Exoression, which will be converted in LongLong2IntExprPass. - */ -std::unique_ptr<StmtPass> CreateLongLong2IntStmtPass(); - -/** - * Converts int64 (long long) types to int32 in a Expr where possible. - * - * IMPORTANT: Before applying this pass, it is MANDATORY to use - * `CanApplyLongLong2Int` to check for potential overflow issues. - * - * This pass is applicable in scenarios where the IR contains int64 types that - * can be safely represented as int32 without overflow. - * - * When applied, this pass will convert int64 expression to int32 - * in various constructs, including: - * - Tensor shapes and indices - * - Loop variables and bounds - * - Buffer metadata (shapes, strides, offsets) - * - Comparison operations (index only) - * - * Examples: - * 1. Loop variable conversion: - * Before conversion: - * { - * ScheduleBlock(root_12) - * { - * attrs(tile_method:TileFirstGeneralTactic) - * thread_bind[blockIdx.x] for (blockIdx.x, 0, 352) - * { - * thread_bind[threadIdx.x] for (threadIdx.x, 0, 256) - * { - * ScheduleBlock(var_2) - * { - * i0, i1, i2, i3 = axis.bind(idx / 4096, (idx % 4096) / 256, (idx % - * 256) / 16, idx % 16) read_buffers(_var[i0(0:22ll), i2(0:16ll)]) - * write_buffers(_var_2[i0(0:22ll), i1(0:16ll), i2(0:16ll), - * i3(0:16ll)]) - * var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16ll] - * } - * } - * } - * } - * } - * - * After conversion: - * { - * ScheduleBlock(root_12) - * { - * attrs(tile_method:TileFirstGeneralTactic) - * thread_bind[blockIdx.x] for (blockIdx.x, 0, 352) - * { - * thread_bind[threadIdx.x] for (threadIdx.x, 0, 256) - * { - * ScheduleBlock(var_2) - * { - * i0, i1, i2, i3 = axis.bind(idx / 4096ll, (idx % 4096ll) / 256ll, - * (idx % 256ll) / 16ll, idx % 16ll) read_buffers(_var[i0(0:22ll), i2(0:16ll)]) - * write_buffers(_var_2[i0(0:22ll), i1(0:16ll), - * i2(0:16ll),i3(0:16ll)]) var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16] - * } - * } - * } - * } - * } - * - * Only 16ll in var[i0, i2, i3 + i1 * 16ll] is converted for other longlong - * Exprs are components of ScheduleBlock, which will be converted in - * LongLong2IntStmtPass. - */ -std::unique_ptr<ExprPass> CreateLongLong2IntExprPass(); - -// Check if the given block can be converted from long long to int, -// A.K.A. the product of the extents of all possible nested loops is within -// INT_MAX -bool CanApplyLongLong2Int(ir::stmt::BlockRef block); /** * Converts int64 (long long) types to int32 in a block where possible. From b5d358772b930e6f69294b0cd0617468a7c064bc Mon Sep 17 00:00:00 2001 From: ZhouXin <zhou.xin@mail.ustc.edu.cn> Date: Mon, 30 Dec 2024 06:05:34 +0000 Subject: [PATCH 8/8] Rename CastLonglong2Int to TryCastLonglong2Int for clarity and update references --- paddle/cinn/optim/longlong2int_pass.cc | 2 +- paddle/cinn/optim/longlong2int_pass.h | 2 +- paddle/cinn/optim/transform_gpu_forloop.cc | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/paddle/cinn/optim/longlong2int_pass.cc b/paddle/cinn/optim/longlong2int_pass.cc index e03649353c5579..306b880b57c88e 100644 --- a/paddle/cinn/optim/longlong2int_pass.cc +++ b/paddle/cinn/optim/longlong2int_pass.cc @@ -254,7 +254,7 @@ bool CanApplyLongLong2Int(ir::stmt::BlockRef block) { return !check_overflow(block); } -void CastLonglong2Int(ir::stmt::BlockRef block) { +void TryCastLonglong2Int(ir::stmt::BlockRef block) { if (CanApplyLongLong2Int(block)) { StmtPassManager stmt_pass_manager; stmt_pass_manager.AddPass(CreateLongLong2IntStmtPass()); diff --git a/paddle/cinn/optim/longlong2int_pass.h b/paddle/cinn/optim/longlong2int_pass.h index d0c35e69d6735d..fa6ba61ad8b6f3 100644 --- a/paddle/cinn/optim/longlong2int_pass.h +++ b/paddle/cinn/optim/longlong2int_pass.h @@ -86,7 +86,7 @@ namespace optim { * } * } */ -void CastLonglong2Int(ir::stmt::BlockRef block); +void TryCastLonglong2Int(ir::stmt::BlockRef block); } // namespace optim } // namespace cinn diff --git a/paddle/cinn/optim/transform_gpu_forloop.cc b/paddle/cinn/optim/transform_gpu_forloop.cc index cb7be4e49d34d6..82eac4839c48e1 100644 --- a/paddle/cinn/optim/transform_gpu_forloop.cc +++ b/paddle/cinn/optim/transform_gpu_forloop.cc @@ -497,8 +497,9 @@ void OptimizeExprGPU(Expr *expr) { if (FLAGS_cinn_longlong2int) { ir::stmt::BlockRef block = ir::ConvertExprBlockToStmtBlock(*expr); VLOG(10) << "Before CastLonglong2Int: \n" << block; - CastLonglong2Int(block); + TryCastLonglong2Int(block); VLOG(10) << "After CastLonglong2Int: \n" << block; + *expr = ir::ConvertStmtBlockToExprBlock(block); } VLOG(4) << "After Optimize Expr: \n" << *expr;