Skip to content

Commit

Permalink
Add align iter space tactic (#60498)
Browse files Browse the repository at this point in the history
Add align iter space tactic
  • Loading branch information
BiynXu authored Jan 3, 2024
1 parent deb5397 commit 698bb42
Show file tree
Hide file tree
Showing 14 changed files with 400 additions and 30 deletions.
3 changes: 2 additions & 1 deletion paddle/cinn/hlir/pe/elementwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <algorithm>
#include <string>

#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/hlir/op/op_util.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/lang/builtin.h"
Expand Down Expand Up @@ -216,7 +217,7 @@ ir::Tensor Reshape(const ir::Tensor& A,
}
std::vector<Expr> indice_a;
for (int i = A_expr_shape.size() - 1; i >= 0; i--) {
auto temp = offset % A_expr_shape[i];
auto temp = common::AutoSimplify(offset % A_expr_shape[i]);
indice_a.insert(indice_a.begin(), temp);
offset = (offset - temp) / A_expr_shape[i];
}
Expand Down
149 changes: 138 additions & 11 deletions paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,31 @@
// limitations under the License.

#include "paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h"
#include "paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h"
#include "paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.h"
#include "paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h"
#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"

namespace cinn {
namespace ir {

void DynamicShapeGroupScheduler::Init() {
std::unordered_set<std::string> output_names = OutputTensorNames();
tactics_.emplace_back(new ComputeInlineTactic(output_names, target_));
tactics_.emplace_back(new ArrangeStorageTactic(output_names));
schedule_context_.output_names = OutputTensorNames();
schedule_context_.global_master = FindGlobalMasterNode();
schedule_context_.iter_space_info =
ConstructIterSpaceInfo(schedule_context_.global_master);
schedule_context_.target = target_;
tactics_.emplace_back(new AlignIterSpaceTactic());
tactics_.emplace_back(new ComputeInlineTactic());
tactics_.emplace_back(new ArrangeStorageTactic());
}

void DynamicShapeGroupScheduler::Schedule() {
// Fake schedule for test
std::vector<Expr> all_blocks = ir_sch_->GetAllBlocks();
for (int i = 0; i < all_blocks.size(); i++) {
std::vector<Expr> loops = ir_sch_->GetLoops(all_blocks[i]);
ir_sch_->Fuse(loops);
}

ApplyTactics();
all_blocks = ir_sch_->GetAllBlocks();
std::vector<Expr> all_blocks = ir_sch_->GetAllBlocks();
auto block0_loops = ir_sch_->GetLoops(all_blocks[0]);
auto splited_loops1 = ir_sch_->Split(block0_loops[0], {1024, -1});

ir_sch_->Bind(splited_loops1[0], "threadIdx.x");

ir::Expr predicate1 = ir::LE::Make(Expr(1023), Expr(1024));
Expand All @@ -49,11 +49,22 @@ void DynamicShapeGroupScheduler::Schedule() {
// Runs every registered tactic over all schedule blocks of the group.
// For each tactic: log the IR before/after, apply the tactic to every
// ScheduleBlockNode in DFS-topological order, then refresh the block graph,
// since a tactic may have rewritten the underlying schedule IR.
void DynamicShapeGroupScheduler::ApplyTactics() {
schedule_block_graph_->Update(*ir_sch_);
for (const auto& tactic : tactics_) {
VLOG(5) << "[Start " << tactic->TacticName() << "] func body:\n"
<< ir_sch_->GetModule().GetExprs().front();
auto ApplyTacticFunc = [&](ir::ScheduleBlockNode* node) {
VLOG(6) << "before applying [" << tactic->TacticName()
<< "] on ScheduleBlockNode [" << node->id() << "] func body:\n"
<< ir_sch_->GetModule().GetExprs().front();
// NOTE(review): Init is re-invoked for every node although the context
// does not change per node — confirm whether once-per-tactic would do.
tactic->Init(&schedule_context_);
tactic->Apply(ir_sch_, node->id());
VLOG(6) << "after applying [" << tactic->TacticName()
<< "] on ScheduleBlockNode [" << node->id() << "] func body:\n"
<< ir_sch_->GetModule().GetExprs().front();
};
schedule_block_graph_->DFSTopoWalk(ApplyTacticFunc);
// The tactic may have mutated the IR; rebuild before the next tactic runs.
schedule_block_graph_->Update(*ir_sch_);
VLOG(5) << "[End " << tactic->TacticName()
<< "] func body: " << ir_sch_->GetModule().GetExprs().front();
}
}

Expand All @@ -67,5 +78,121 @@ DynamicShapeGroupScheduler::GetIRs() {
return irs;
}

// Builds the iteration-space description of the given schedule block node.
// Later tactics (e.g. AlignIterSpaceTactic) use this as the target space to
// which other blocks in the group are aligned.
//
// For a reduction block, axes are classified into spatial (sp_space) and
// reduction (rb_space) axes by inspecting the indices of the single load that
// consumes the reduce iter vars; `rb_last_order` records the permutation that
// places all reduction axes after the spatial ones. For a non-reduction block
// every loop is recorded as a serial spatial axis in its original order.
IterativeSpaceInfo DynamicShapeGroupScheduler::ConstructIterSpaceInfo(
    ScheduleBlockNode* node) {
  IterativeSpaceInfo info;
  std::vector<int> sp_iter_indices;
  std::vector<int> rb_iter_indices;

  ir::Expr block = node->Block();
  std::vector<ir::Expr> iter_values =
      block.As<ir::ScheduleBlockRealize>()->iter_values;
  std::vector<ir::Var> iter_vars = block.As<ir::ScheduleBlockRealize>()
                                       ->schedule_block.As<ir::ScheduleBlock>()
                                       ->iter_vars;
  std::vector<ir::Expr> loops = ir_sch_->GetLoops(block);
  std::unordered_set<ir::Var> reduce_iter_vars =
      analyzer::GetReduceIterVars(block);
  std::unordered_map<ir::Var, ir::Expr> iter_var2value =
      analyzer::GetIterVarToValueOfSBlock(block);

  if (!reduce_iter_vars.empty()) {
    // Find the unique load whose indices contain a reduce iter var; the order
    // of its indices defines the memory-consistent order of the space.
    std::set<ir::Expr> reduce_loads = ir::ir_utils::CollectIRNodesWithoutTensor(
        block,
        [&](const ir::Expr* x) {
          bool find_reduce_var = false;
          if (x->As<ir::Load>()) {
            for (ir::Expr index : x->As<ir::Load>()->indices) {
              if (index.as_var() &&
                  reduce_iter_vars.count(index.as_var_ref()) > 0) {
                find_reduce_var = true;
                break;
              }
            }
          }
          return find_reduce_var;
        },
        /* uniq_target = */ true);
    CHECK_EQ(reduce_loads.size(), 1);

    std::vector<ir::Expr> reduce_load_indices =
        reduce_loads.begin()->As<ir::Load>()->indices;
    int loop_idx = 0;
    for (size_t i = 0; i < reduce_load_indices.size(); ++i) {
      ir::Expr& index = reduce_load_indices[i];
      // Constant indices do not correspond to any loop axis.
      if (index.is_constant()) continue;
      CHECK_NOTNULL(index.as_var());
      ir::Var iter_var = index.as_var_ref();
      ir::Expr iter_value = iter_var2value.at(iter_var);
      CHECK_NOTNULL(iter_value.as_var());
      // Locate the For node whose loop var binds this block iter value.
      // Initialized to nullptr so CHECK_NOTNULL below actually fires when no
      // matching loop exists (previously read an uninitialized pointer: UB).
      ir::For* for_node = nullptr;
      for (ir::Expr& loop : loops) {
        if (loop.As<ir::For>()->loop_var == iter_value.as_var_ref()) {
          for_node = loop.As<ir::For>();
          break;
        }
      }
      CHECK_NOTNULL(for_node);
      bool is_reduce_iter_var = reduce_iter_vars.count(iter_var) > 0;
      if (is_reduce_iter_var) {
        info.rb_space.emplace_back(for_node->extent,
                                   IterativeSpaceInfo::AxisType::kSerial);
        info.memory_consistent_order_space.emplace_back(for_node->extent);
        rb_iter_indices.push_back(loop_idx);
      } else {
        info.sp_space.emplace_back(for_node->extent,
                                   IterativeSpaceInfo::AxisType::kSerial);
        info.memory_consistent_order_space.emplace_back(for_node->extent);
        sp_iter_indices.push_back(loop_idx);
      }
      ++loop_idx;
    }
    // Permutation = all spatial axis positions first, then all reduce axes.
    info.rb_last_order.insert(info.rb_last_order.end(),
                              sp_iter_indices.begin(),
                              sp_iter_indices.end());
    info.rb_last_order.insert(info.rb_last_order.end(),
                              rb_iter_indices.begin(),
                              rb_iter_indices.end());
  } else {
    // No reduction: every loop is a serial spatial axis in original order.
    for (int i = 0; i < loops.size(); ++i) {
      ir::For* for_node = loops[i].As<ir::For>();
      info.memory_consistent_order_space.emplace_back(for_node->extent);
      info.sp_space.emplace_back(for_node->extent,
                                 IterativeSpaceInfo::AxisType::kSerial);
      info.rb_last_order.push_back(i);
    }
  }
  return info;
}

// Selects the global master schedule block of the group, in priority order:
// (1) the last reduction block found, (2) the last broadcast block found,
// (3) the last end-point node of the schedule block graph.
ir::ScheduleBlockNode* DynamicShapeGroupScheduler::FindGlobalMasterNode() {
  ir::ScheduleBlockNode* candidate = nullptr;
  // Priority 1: reduction blocks.
  schedule_block_graph_->NodesWalk([&](ir::ScheduleBlockNode* n) {
    if (analyzer::IsReductionSBlock(n->Block())) {
      candidate = n;
    }
  });
  if (candidate == nullptr) {
    // Priority 2: broadcast blocks.
    schedule_block_graph_->NodesWalk([&](ir::ScheduleBlockNode* n) {
      if (analyzer::IsBroadcastSBlock(n->Block())) {
        candidate = n;
      }
    });
  }
  if (candidate == nullptr) {
    // Priority 3: fall back to the last end point of the graph.
    candidate = schedule_block_graph_->EndPoints().back();
  }
  VLOG(6) << "Find the global master node: " << candidate->id();
  return candidate;
}

} // namespace ir
} // namespace cinn
5 changes: 5 additions & 0 deletions paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,15 @@ class DynamicShapeGroupScheduler : public GroupScheduler {

void ApplyTactics();

ir::ScheduleBlockNode* FindGlobalMasterNode();

IterativeSpaceInfo ConstructIterSpaceInfo(ScheduleBlockNode* node);

private:
std::vector<std::pair<SymbolicPredicate, std::unique_ptr<ir::IRSchedule>>>
ir_schs_;
std::vector<std::unique_ptr<ScheduleTactic>> tactics_;
ScheduleContext schedule_context_;
};

} // namespace ir
Expand Down
3 changes: 2 additions & 1 deletion paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
core_gather_headers()

gather_srcs(cinnapi_src SRCS arrange_storage_tactic.cc)
gather_srcs(cinnapi_src SRCS align_iter_space_tactic.cc)
gather_srcs(cinnapi_src SRCS compute_inline_tactic.cc)
gather_srcs(cinnapi_src SRCS arrange_storage_tactic.cc)
87 changes: 87 additions & 0 deletions paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h"
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/integer_set.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_copy.h"

namespace cinn {
namespace ir {

// Stores the shared schedule context (iteration-space info of the global
// master, output names, target) for use by Apply(). The context is owned by
// the scheduler, not by this tactic.
void AlignIterSpaceTactic::Init(ScheduleContext* context) {
context_ = context;
}

// Reshapes the loop nest of a non-reduction schedule block so its iteration
// space matches the global master's: fuse everything into one loop, prove the
// total extent equals the target space's total extent, re-split into the
// memory-consistent order, move reduce axes last, then fuse the spatial axes
// and the reduce axes back into (at most) two loops.
void AlignIterSpaceTactic::Apply(ir::IRSchedule* sch,
                                 const std::string& block_id) {
  ir::Expr block = sch->GetBlock(block_id);
  // Reduction blocks define the target space themselves; leave them alone.
  if (analyzer::IsReductionSBlock(block)) {
    return;
  }

  // Collapse the whole nest into a single loop.
  std::vector<ir::Expr> cur_loops = sch->GetLoops(block_id);
  ir::Expr fused_loop = sch->Fuse(cur_loops);
  ir::Expr src_extent = fused_loop.As<ir::For>()->extent;

  // Target total extent = product of spatial extents * reduce extents.
  ir::Expr target_sp_extent{1};
  for (const auto& axis : context_->iter_space_info.sp_space) {
    target_sp_extent = target_sp_extent * std::get<0>(axis);
  }
  ir::Expr target_extent = ir_utils::IRCopy(target_sp_extent);
  for (const auto& axis : context_->iter_space_info.rb_space) {
    target_extent = target_extent * std::get<0>(axis);
  }

  // Symbolically prove the two total extents are equal before re-splitting.
  common::cas_intervals_t var_intervals;
  common::SymbolicExprAnalyzer sym_analyzer(var_intervals);
  std::optional<bool> extent_eq =
      sym_analyzer.ProveEQ(src_extent, target_extent);

  // A reorder is only needed if rb_last_order is not the identity permutation.
  bool need_reorder = false;
  for (size_t i = 0; i < context_->iter_space_info.rb_last_order.size(); ++i) {
    if (context_->iter_space_info.rb_last_order[i] != static_cast<int>(i)) {
      need_reorder = true;
      break;
    }
  }

  if (extent_eq.has_value() && extent_eq.value()) {
    sch->Split(fused_loop,
               context_->iter_space_info.memory_consistent_order_space);
    cur_loops = sch->GetLoops(block_id);
    if (need_reorder) {
      sch->Reorder(block_id, context_->iter_space_info.rb_last_order);
    }
    const size_t sp_size = context_->iter_space_info.sp_space.size();
    // Fuse all trailing reduce-axis loops into a single loop.
    if (sp_size < cur_loops.size() - 1) {
      cur_loops = sch->GetLoops(block_id);
      std::vector<ir::Expr> rb_loops(cur_loops.begin() + sp_size,
                                     cur_loops.end());
      sch->Fuse(rb_loops);
    }
    // Fuse all leading spatial loops into a single loop.
    if (sp_size > 1) {
      cur_loops = sch->GetLoops(block_id);
      std::vector<ir::Expr> sp_loops(cur_loops.begin(),
                                     cur_loops.begin() + sp_size);
      sch->Fuse(sp_loops);
    }
  }
}

} // namespace ir
} // namespace cinn
37 changes: 37 additions & 0 deletions paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <unordered_set>
#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h"

namespace cinn {
namespace ir {

// Tactic that reorganizes the loop nest of each non-reduction schedule block
// so its iteration space matches that of the group's global master block
// (spatial axes first, reduce axes last, each side fused into one loop).
class AlignIterSpaceTactic final : public ScheduleTactic {
 public:
  void Init(ScheduleContext* context) override;

  void Apply(ir::IRSchedule* sch, const std::string& block_id) override;

  std::string TacticName() const override { return "AlignIterSpaceTactic"; }

 private:
  // Borrowed from the scheduler; not owned. Set by Init().
  // Initialized to nullptr so a use before Init() fails loudly instead of
  // dereferencing an indeterminate pointer.
  ScheduleContext* context_{nullptr};
};

} // namespace ir
} // namespace cinn
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
namespace cinn {
namespace ir {

// [block_name, [var_name, for_node]]
// [block_name, [var, for_node]]
using VarToForMap =
std::unordered_map<std::string, std::unordered_map<ir::Var, ir::Expr>>;
using IntSet = common::SingleIntervalIntSet;
Expand Down Expand Up @@ -337,9 +337,9 @@ std::optional<CudaAxisType> AnalyzeCrossType(const VarToForMap& var2for_map,
return std::nullopt;
}

ArrangeStorageTactic::ArrangeStorageTactic(
const std::unordered_set<std::string>& output_names)
: output_names_(output_names) {}
// Caches the group's output tensor names from the shared schedule context for
// later use by Apply(). Replaces the former constructor-argument plumbing so
// all tactics share the uniform Init(ScheduleContext*) entry point.
void ArrangeStorageTactic::Init(ScheduleContext* context) {
output_names_ = context->output_names;
}

void ArrangeStorageTactic::Apply(ir::IRSchedule* sch,
const std::string& block_id) {
Expand Down
5 changes: 3 additions & 2 deletions paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ namespace ir {

class ArrangeStorageTactic final : public ScheduleTactic {
public:
explicit ArrangeStorageTactic(
const std::unordered_set<std::string>& output_names);
void Init(ScheduleContext* context) override;

void Apply(ir::IRSchedule* sch, const std::string& block_id) override;

std::string TacticName() const override { return "ArrangeStorageTactic"; }

private:
std::unordered_set<std::string> output_names_;
};
Expand Down
Loading

0 comments on commit 698bb42

Please sign in to comment.