diff --git a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc
index 657742e37ab42..628ad3fbfef9d 100644
--- a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc
+++ b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc
@@ -17,6 +17,7 @@
 #include "paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.h"
 #include "paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h"
 #include "paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h"
+#include "paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.h"
 #include "paddle/cinn/ir/group_schedule/tactic/tile_tactic.h"
 #include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
 #include "paddle/cinn/ir/op/ir_operators.h"
@@ -27,8 +28,9 @@ namespace ir {
 void DynamicShapeGroupScheduler::Init() {
   InitBuckets();
   tactics_.emplace_back(new AlignIterSpaceTactic());
-  tactics_.emplace_back(new TileTactic());
   tactics_.emplace_back(new ComputeInlineTactic());
+  tactics_.emplace_back(new TileTactic());
+  tactics_.emplace_back(new OptimizeReductionTactic());
   tactics_.emplace_back(new BindCudaTactic());
   tactics_.emplace_back(new ArrangeStorageTactic());
 }
diff --git a/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt b/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt
index b12e669b8c2d0..e8205f7244bb1 100644
--- a/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt
+++ b/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt
@@ -3,5 +3,6 @@ core_gather_headers()
 gather_srcs(cinnapi_src SRCS align_iter_space_tactic.cc)
 gather_srcs(cinnapi_src SRCS tile_tactic.cc)
 gather_srcs(cinnapi_src SRCS compute_inline_tactic.cc)
+gather_srcs(cinnapi_src SRCS optimize_reduction_tactic.cc)
 gather_srcs(cinnapi_src SRCS bind_cuda_tactic.cc)
 gather_srcs(cinnapi_src SRCS arrange_storage_tactic.cc)
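
The reordering in DynamicShapeGroupScheduler::Init is the functional part of this hunk: TileTactic now runs after ComputeInlineTactic, and the new OptimizeReductionTactic runs between tiling and CUDA binding, so it sees the tiled loop nest before any threads are bound. A rough sketch of how such an ordered tactic list is typically driven, assuming the Init/Apply interface introduced in this PR (the ApplyTactics driver itself is a hypothetical illustration, not code from this diff):

#include <memory>
#include <string>
#include <vector>
#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h"

// Hypothetical driver: applies each registered tactic, in order, to every
// schedule block. Mirrors the tactics_ ordering set up in Init() above.
void ApplyTactics(
    cinn::ir::IRSchedule* sch,
    cinn::ir::ScheduleContext* context,
    const std::vector<std::unique_ptr<cinn::ir::ScheduleTactic>>& tactics,
    const std::vector<std::string>& block_ids) {
  for (const auto& tactic : tactics) {
    tactic->Init(context);  // bind the shared scheduling context
    for (const std::string& block_id : block_ids) {
      tactic->Apply(sch, block_id);  // e.g. OptimizeReductionTactic::Apply
    }
  }
}
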
diff --git a/paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.cc
new file mode 100644
index 0000000000000..1721e99d657f6
--- /dev/null
+++ b/paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.cc
@@ -0,0 +1,126 @@
+// Copyright (c) 2023 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.h"
+#include "paddle/cinn/ir/ir.h"
+#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
+
+namespace cinn {
+namespace ir {
+
+void OptimizeReductionTactic::Init(ScheduleContext* context) {
+  context_ = context;
+}
+
+bool CanApply(const std::string& block_name, ir::IRSchedule* sch) {
+  ir::Expr block_expr = sch->GetBlock(block_name);
+  ir::ScheduleBlockRealize* block_realize =
+      block_expr.As<ir::ScheduleBlockRealize>();
+  CHECK_NOTNULL(block_realize);
+  ir::ScheduleBlock* sch_block =
+      block_realize->schedule_block.As<ir::ScheduleBlock>();
+  CHECK_NOTNULL(sch_block);
+  analyzer::AnalyzeScheduleBlockReadWriteBuffer(sch_block);
+
+  // 1. The block must have a write buffer
+  if (sch_block->write_buffers.empty()) {
+    return false;
+  }
+
+  // 2. The block must have at least one reduce axis
+  const std::vector<ir::Var>& iter_vars = sch_block->iter_vars;
+  bool find_reduce_axis = false;
+  for (int i = 0; i < iter_vars.size(); ++i) {
+    if (iter_vars[i]->is_reduce_axis) {
+      find_reduce_axis = true;
+      break;
+    }
+  }
+  if (!find_reduce_axis) {
+    return false;
+  }
+
+  // 3. Each loop's body must contain only one sub-loop or block, except
+  // for the reduce_init block
+  std::vector<ir::Expr> loops = sch->GetLoops(block_name);
+  for (const ir::Expr& loop : loops) {
+    const ir::Expr& body = loop.As<ir::For>()->body;
+    if (body.As<ir::Block>()) {
+      if (body.As<ir::Block>()->stmts.size() == 1) {
+        if (body.As<ir::Block>()->stmts[0].As<ir::For>() == nullptr &&
+            body.As<ir::Block>()->stmts[0].As<ir::ScheduleBlockRealize>() ==
+                nullptr) {
+          return false;
+        }
+      } else if (body.As<ir::Block>()->stmts.size() == 2) {
+        if (body.As<ir::Block>()->stmts[0].As<ir::ScheduleBlockRealize>() ==
+                nullptr ||
+            !ir::IsReduceInitTensorName(
+                analyzer::GetBlockName(body.As<ir::Block>()->stmts[0]))) {
+          return false;
+        }
+        if (body.As<ir::Block>()->stmts[1].As<ir::For>() == nullptr &&
+            body.As<ir::Block>()->stmts[1].As<ir::ScheduleBlockRealize>() ==
+                nullptr) {
+          return false;
+        }
+      } else {
+        return false;
+      }
+    } else if (body.As<ir::For>() || body.As<ir::ScheduleBlockRealize>()) {
+      continue;
+    } else {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+void OptimizeReductionTactic::Apply(ir::IRSchedule* sch,
+                                    const std::string& block_id) {
+  if (!CanApply(block_id, sch)) return;
+
+  std::vector<ir::Expr> loops = sch->GetLoops(block_id);
+  int first_reduce_loop_idx = context_->iter_space_info.sp_space.size();
+  CHECK_LT(first_reduce_loop_idx, loops.size())
+      << "first_reduce_loop_idx should be less than the number of loops.";
+  // Apply FactorizeReduction
+  VLOG(6) << "before FactorizeReduction: " << sch->GetModule().GetExprs()[0];
+  sch->FactorizeReduction(loops[first_reduce_loop_idx], first_reduce_loop_idx);
+  VLOG(6) << "after FactorizeReduction: " << sch->GetModule().GetExprs()[0];
+
+  // Loop fusion and cross thread reduction
+  std::vector<ir::Expr> rb_loops = sch->GetLoops(block_id);
+  std::string rf_block_id = block_id + "_rf";
+  ir::Expr rf_block = sch->GetBlock(rf_block_id);
+  sch->SimpleComputeAt(rf_block, rb_loops.back());
+
+  rb_loops = sch->GetLoops(block_id);
+  ir::Expr rf_init_block =
+      sch->GetBlock(ir::GenReduceInitTensorNameOf(rf_block_id));
+  sch->SimpleComputeAt(rf_init_block, rb_loops.back());
+
+  if (context_->target == cinn::common::DefaultNVGPUTarget()) {
+    rb_loops = sch->GetLoops(block_id);
+    rf_block = sch->GetBlock(rf_block_id);
+    sch->Bind(rb_loops.back(), "threadIdx.x");
+    sch->SetBuffer(rf_block, "shared");
+  }
+  VLOG(6) << "Loop fusion and cross thread reduction: "
+          << sch->GetModule().GetExprs()[0];
+}
+
+}  // namespace ir
+}  // namespace cinn
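
For context on what Apply does: FactorizeReduction splits the serial reduction into a partial-result ("rf") stage over the factored axis plus a final reduction across that axis; the subsequent SimpleComputeAt/Bind calls fuse the stages and turn the final loop into a cross-thread reduction on threadIdx.x, with the rf buffer placed in shared memory. A plain-C++ analogue of the data flow on a row sum (illustrative only, not the CINN API):

#include <vector>

// Row-sum analogue of FactorizeReduction: out[i] = sum_j in[i][j].
// Stage 1 writes one partial per (j, i) into an "rf" buffer (per-thread
// partials on the GPU path); stage 2 reduces across the factored axis j
// (the part that becomes the cross-thread reduction bound to threadIdx.x).
std::vector<float> ReduceWithRF(const std::vector<std::vector<float>>& in) {
  const size_t n = in.size();             // spatial extent
  const size_t k = n ? in[0].size() : 0;  // reduce extent
  std::vector<std::vector<float>> rf(k, std::vector<float>(n, 0.0f));
  for (size_t i = 0; i < n; ++i) {        // stage 1: the rf block
    for (size_t j = 0; j < k; ++j) {
      rf[j][i] += in[i][j];
    }
  }
  std::vector<float> out(n, 0.0f);
  for (size_t i = 0; i < n; ++i) {        // stage 2: final reduce block
    for (size_t j = 0; j < k; ++j) {
      out[i] += rf[j][i];
    }
  }
  return out;
}
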
diff --git a/paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.h b/paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.h
new file mode 100644
index 0000000000000..108f674ee2253
--- /dev/null
+++ b/paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.h
@@ -0,0 +1,36 @@
+// Copyright (c) 2023 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h"
+
+namespace cinn {
+namespace ir {
+
+class OptimizeReductionTactic final : public ScheduleTactic {
+ public:
+  void Init(ScheduleContext* context) override;
+
+  void Apply(ir::IRSchedule* sch, const std::string& block_id) override;
+
+  std::string TacticName() const override { return "OptimizeReductionTactic"; }
+
+ private:
+  ScheduleContext* context_;
+};
+
+}  // namespace ir
+}  // namespace cinn
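
The header above is the whole contract: a tactic only needs Init, Apply, and TacticName. A hypothetical skeleton of another tactic against the same interface (not part of this diff), to make the extension point explicit:

#include <string>
#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h"

namespace cinn {
namespace ir {

// Hypothetical example, not part of this PR: the minimal shape any new
// tactic takes before being registered in DynamicShapeGroupScheduler::Init.
class NoOpTactic final : public ScheduleTactic {
 public:
  void Init(ScheduleContext* context) override { context_ = context; }
  void Apply(ir::IRSchedule* sch, const std::string& block_id) override {
    // Inspect sch->GetLoops(block_id) and rewrite the loop nest here.
  }
  std::string TacticName() const override { return "NoOpTactic"; }

 private:
  ScheduleContext* context_;
};

}  // namespace ir
}  // namespace cinn
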
diff --git a/paddle/cinn/ir/ir_analyzer/ir_analyzer.cc b/paddle/cinn/ir/ir_analyzer/ir_analyzer.cc
index 701d003dbcd2d..b75f12712853f 100644
--- a/paddle/cinn/ir/ir_analyzer/ir_analyzer.cc
+++ b/paddle/cinn/ir/ir_analyzer/ir_analyzer.cc
@@ -23,6 +23,8 @@
 #include <unordered_set>
 #include <vector>
 
+#include "paddle/cinn/common/context.h"
+#include "paddle/cinn/common/integer_set.h"
 #include "paddle/cinn/ir/ir_mutator.h"
 #include "paddle/cinn/ir/ir_printer.h"
 #include "paddle/cinn/ir/ir_visitor.h"
@@ -440,6 +442,74 @@ bool IsBroadcastSBlock(ir::Expr block) {
   return load->indices.size() < store->indices.size();
 }
 
+std::vector<ir::Var> IndicesToVars(const std::vector<ir::Expr>& indices) {
+  std::vector<ir::Var> result;
+  for (const ir::Expr& e : indices) {
+    if (e.is_constant()) {
+      std::string var_name = cinn::UniqName(
+          "constant" + std::to_string(static_cast<int>(e.get_constant())));
+      result.emplace_back(e, e, var_name, /* is_reduce = */ false);
+    } else if (e.As<ir::_Var_>() != nullptr) {
+      ir::Expr copy_e = ir::ir_utils::IRCopy(e);
+      ir::_Var_* var_ref = copy_e.As<ir::_Var_>();
+      result.emplace_back(ir::Var(var_ref));
+    } else {
+      std::string var_name = cinn::UniqName("expr");
+      common::cas_intervals_t var_intervals;
+      bool is_reduce = false;
+      ir::ir_utils::CollectIRNodes(e, [&](const ir::Expr* x) {
+        if (x->As<ir::_Var_>() != nullptr) {
+          ir::Var var = x->as_var_ref();
+          var_intervals.insert(
+              {var->name,
+               common::CasInterval{var->lower_bound, var->upper_bound}});
+          if (var->is_reduce_axis) is_reduce = true;
+        }
+        return false;
+      });
+      common::SymbolicExprAnalyzer analyzer(var_intervals);
+      result.emplace_back(
+          analyzer.LowerBound(e), analyzer.UpperBound(e), var_name, is_reduce);
+    }
+  }
+  return result;
+}
+
+void AnalyzeScheduleBlockReadWriteBuffer(ir::ScheduleBlock* sche_block) {
+  if (!sche_block->read_buffers.empty() ||
+      !sche_block->write_buffers.empty()) {
+    return;
+  }
+
+  ir::ir_utils::CollectIRNodesWithoutTensor(
+      sche_block->body, [&](const Expr* x) {
+        const ir::Load* load_expr = x->As<ir::Load>();
+        if (load_expr != nullptr) {
+          const ir::Tensor t = load_expr->tensor.as_tensor_ref();
+          sche_block->read_buffers.emplace_back(
+              ir::BufferRange(t->buffer, IndicesToVars(load_expr->indices)));
+          return false;
+        }
+        const ir::Store* store_expr = x->As<ir::Store>();
+        if (store_expr != nullptr) {
+          const ir::Tensor t = store_expr->tensor.as_tensor_ref();
+          sche_block->write_buffers.emplace_back(
+              ir::BufferRange(t->buffer, IndicesToVars(store_expr->indices)));
+          return false;
+        }
+        return false;
+      });
+}
+
+std::string GetBlockName(const ir::Expr block) {
+  const ir::ScheduleBlockRealize* block_realize =
+      block.As<ir::ScheduleBlockRealize>();
+  CHECK_NOTNULL(block_realize);
+  const ir::ScheduleBlock* block_node =
+      block_realize->schedule_block.As<ir::ScheduleBlock>();
+  CHECK_NOTNULL(block_node);
+  return block_node->name;
+}
+
 }  // namespace analyzer
 }  // namespace ir
 }  // namespace cinn
diff --git a/paddle/cinn/ir/ir_analyzer/ir_analyzer.h b/paddle/cinn/ir/ir_analyzer/ir_analyzer.h
index 50f3a4eafaf2d..03997607fd90e 100644
--- a/paddle/cinn/ir/ir_analyzer/ir_analyzer.h
+++ b/paddle/cinn/ir/ir_analyzer/ir_analyzer.h
@@ -73,6 +73,12 @@ bool IsReductionSBlock(ir::Expr block);
 
 bool IsBroadcastSBlock(ir::Expr block);
 
+std::vector<ir::Var> IndicesToVars(const std::vector<ir::Expr>& indices);
+
+void AnalyzeScheduleBlockReadWriteBuffer(ir::ScheduleBlock* sche_block);
+
+std::string GetBlockName(const ir::Expr block);
+
 }  // namespace analyzer
 }  // namespace ir
 }  // namespace cinn
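
The interesting case in IndicesToVars is the final branch: a compound index expression is summarized by the interval [LowerBound(e), UpperBound(e)] derived from the ranges of the variables it contains, so BufferRange still receives a usable access extent. A self-contained sketch of that interval reasoning for an affine index, with illustrative names (the real SymbolicExprAnalyzer handles general symbolic expressions, not just affine ones):

#include <algorithm>

// Sketch of the bound reasoning IndicesToVars relies on: given the range of
// a loop variable i, bound the index expression a * i + b. With i in
// [lo, hi] and a possibly negative, the extremes occur at the endpoints.
struct Interval {
  long lo;
  long hi;
};

Interval BoundAffineIndex(long a, long b, Interval i) {
  long at_lo = a * i.lo + b;
  long at_hi = a * i.hi + b;
  return {std::min(at_lo, at_hi), std::max(at_lo, at_hi)};
}
// e.g. i in [0, 31], index 4 * i + 2  ->  [2, 126]; this pair plays the
// role of the (LowerBound, UpperBound) arguments stored in the result Var.
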