Add align iter space tactic #60498

Merged
3 changes: 2 additions & 1 deletion paddle/cinn/hlir/pe/elementwise.cc
@@ -17,6 +17,7 @@
 #include <algorithm>
 #include <string>

+#include "paddle/cinn/common/cas.h"
 #include "paddle/cinn/hlir/op/op_util.h"
 #include "paddle/cinn/ir/op/ir_operators.h"
 #include "paddle/cinn/lang/builtin.h"
@@ -216,7 +217,7 @@ ir::Tensor Reshape(const ir::Tensor& A,
   }
   std::vector<Expr> indice_a;
   for (int i = A_expr_shape.size() - 1; i >= 0; i--) {
-    auto temp = offset % A_expr_shape[i];
+    auto temp = common::AutoSimplify(offset % A_expr_shape[i]);
     indice_a.insert(indice_a.begin(), temp);
     offset = (offset - temp) / A_expr_shape[i];
   }
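Note on the change above: the only functional difference is wrapping each per-axis remainder in common::AutoSimplify, so that under symbolic (dynamic) shapes the expressions stored into indice_a stay compact instead of accumulating nested modulo-and-division terms. For reference, a minimal sketch of the decomposition this loop performs, written over plain integers rather than CINN's symbolic ir::Expr (DecomposeOffset is an illustrative name, not a CINN API):

#include <cstdio>
#include <vector>

// Peel a flat offset into per-dimension indices, innermost axis first,
// mirroring the loop in Reshape above.
std::vector<int> DecomposeOffset(int offset, const std::vector<int>& shape) {
  std::vector<int> indices(shape.size());
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    int temp = offset % shape[i];         // index along axis i
    indices[i] = temp;
    offset = (offset - temp) / shape[i];  // drop axis i's contribution
  }
  return indices;
}

int main() {
  // Flat offset 11 in a 3x4 tensor is element (2, 3).
  std::vector<int> idx = DecomposeOffset(11, {3, 4});
  std::printf("(%d, %d)\n", idx[0], idx[1]);
  return 0;
}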
149 changes: 138 additions & 11 deletions paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc
@@ -13,31 +13,31 @@
 // limitations under the License.

 #include "paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h"
+#include "paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h"
 #include "paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.h"
 #include "paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h"
 #include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"

 namespace cinn {
 namespace ir {

 void DynamicShapeGroupScheduler::Init() {
-  std::unordered_set<std::string> output_names = OutputTensorNames();
-  tactics_.emplace_back(new ComputeInlineTactic(output_names, target_));
-  tactics_.emplace_back(new ArrangeStorageTactic(output_names));
+  schedule_context_.output_names = OutputTensorNames();
+  schedule_context_.global_master = FindGlobalMasterNode();
+  schedule_context_.iter_space_info =
+      ConstructIterSpaceInfo(schedule_context_.global_master);
+  schedule_context_.target = target_;
+  tactics_.emplace_back(new AlignIterSpaceTactic());
+  tactics_.emplace_back(new ComputeInlineTactic());
+  tactics_.emplace_back(new ArrangeStorageTactic());
 }

 void DynamicShapeGroupScheduler::Schedule() {
-  // Fake schedule for test
-  std::vector<Expr> all_blocks = ir_sch_->GetAllBlocks();
-  for (int i = 0; i < all_blocks.size(); i++) {
-    std::vector<Expr> loops = ir_sch_->GetLoops(all_blocks[i]);
-    ir_sch_->Fuse(loops);
-  }
-
-  all_blocks = ir_sch_->GetAllBlocks();
+  ApplyTactics();
+  std::vector<Expr> all_blocks = ir_sch_->GetAllBlocks();
   auto block0_loops = ir_sch_->GetLoops(all_blocks[0]);
   auto splited_loops1 = ir_sch_->Split(block0_loops[0], {1024, -1});

   ir_sch_->Bind(splited_loops1[0], "threadIdx.x");

   ir::Expr predicate1 = ir::LE::Make(Expr(1023), Expr(1024));
@@ -49,11 +49,22 @@ void DynamicShapeGroupScheduler::Schedule() {
 void DynamicShapeGroupScheduler::ApplyTactics() {
   schedule_block_graph_->Update(*ir_sch_);
   for (const auto& tactic : tactics_) {
+    VLOG(5) << "[Start " << tactic->TacticName() << "] func body:\n"
+            << ir_sch_->GetModule().GetExprs().front();
     auto ApplyTacticFunc = [&](ir::ScheduleBlockNode* node) {
+      VLOG(6) << "before applying [" << tactic->TacticName()
+              << "] on ScheduleBlockNode [" << node->id() << "] func body:\n"
+              << ir_sch_->GetModule().GetExprs().front();
+      tactic->Init(&schedule_context_);
       tactic->Apply(ir_sch_, node->id());
+      VLOG(6) << "after applying [" << tactic->TacticName()
+              << "] on ScheduleBlockNode [" << node->id() << "] func body:\n"
+              << ir_sch_->GetModule().GetExprs().front();
     };
     schedule_block_graph_->DFSTopoWalk(ApplyTacticFunc);
     schedule_block_graph_->Update(*ir_sch_);
+    VLOG(5) << "[End " << tactic->TacticName()
+            << "] func body: " << ir_sch_->GetModule().GetExprs().front();
   }
 }

@@ -67,5 +78,121 @@ DynamicShapeGroupScheduler::GetIRs() {
return irs;
}

IterativeSpaceInfo DynamicShapeGroupScheduler::ConstructIterSpaceInfo(
    ScheduleBlockNode* node) {
  IterativeSpaceInfo info;
  std::vector<int> sp_iter_indices;
  std::vector<int> rb_iter_indices;

  ir::Expr block = node->Block();
  std::vector<ir::Expr> iter_values =
      block.As<ir::ScheduleBlockRealize>()->iter_values;
  std::vector<ir::Var> iter_vars = block.As<ir::ScheduleBlockRealize>()
                                       ->schedule_block.As<ir::ScheduleBlock>()
                                       ->iter_vars;
  std::vector<ir::Expr> loops = ir_sch_->GetLoops(block);
  std::unordered_set<ir::Var> reduce_iter_vars =
      analyzer::GetReduceIterVars(block);
  std::unordered_map<ir::Var, ir::Expr> iter_var2value =
      analyzer::GetIterVarToValueOfSBlock(block);

  if (!reduce_iter_vars.empty()) {
    // Reduction case: the memory-consistent axis order is recovered from the
    // unique Load whose indices consume the reduce iter vars.
    std::set<ir::Expr> reduce_loads = ir::ir_utils::CollectIRNodesWithoutTensor(
        block,
        [&](const ir::Expr* x) {
          bool find_reduce_var = false;
          if (x->As<ir::Load>()) {
            for (ir::Expr index : x->As<ir::Load>()->indices) {
              if (index.as_var() &&
                  reduce_iter_vars.count(index.as_var_ref()) > 0) {
                find_reduce_var = true;
                break;
              }
            }
          }
          return find_reduce_var;
        },
        /* uniq_target = */ true);
    CHECK_EQ(reduce_loads.size(), 1);

    std::vector<ir::Expr> reduce_load_indices =
        reduce_loads.begin()->As<ir::Load>()->indices;
    int loop_idx = 0;
    for (int i = 0; i < reduce_load_indices.size(); ++i) {
      ir::Expr& index = reduce_load_indices[i];
      if (index.is_constant()) continue;
      CHECK_NOTNULL(index.as_var());
      ir::Var iter_var = index.as_var_ref();
      ir::Expr iter_value = iter_var2value.at(iter_var);
      CHECK_NOTNULL(iter_value.as_var());
      // Locate the For node whose loop_var this block iter var binds to;
      // initialized to nullptr so the CHECK_NOTNULL below is meaningful.
      ir::For* for_node = nullptr;
      for (ir::Expr& loop : loops) {
        if (loop.As<ir::For>()->loop_var == iter_value.as_var_ref()) {
          for_node = loop.As<ir::For>();
        }
      }
      CHECK_NOTNULL(for_node);
      bool is_reduce_iter_var = reduce_iter_vars.count(iter_var) > 0;
      if (is_reduce_iter_var) {
        info.rb_space.emplace_back(for_node->extent,
                                   IterativeSpaceInfo::AxisType::kSerial);
        info.memory_consistent_order_space.emplace_back(for_node->extent);
        rb_iter_indices.push_back(loop_idx);
      } else {
        info.sp_space.emplace_back(for_node->extent,
                                   IterativeSpaceInfo::AxisType::kSerial);
        info.memory_consistent_order_space.emplace_back(for_node->extent);
        sp_iter_indices.push_back(loop_idx);
      }
      ++loop_idx;
    }
    // rb_last_order: all spatial axes first, then all reduce axes.
    info.rb_last_order.insert(info.rb_last_order.end(),
                              sp_iter_indices.begin(),
                              sp_iter_indices.end());
    info.rb_last_order.insert(info.rb_last_order.end(),
                              rb_iter_indices.begin(),
                              rb_iter_indices.end());
  } else {
    // No reduction: every loop is a spatial axis and the current loop order
    // is already memory-consistent.
    for (int i = 0; i < loops.size(); ++i) {
      ir::For* for_node = loops[i].As<ir::For>();
      info.memory_consistent_order_space.emplace_back(for_node->extent);
      info.sp_space.emplace_back(for_node->extent,
                                 IterativeSpaceInfo::AxisType::kSerial);
      info.rb_last_order.push_back(i);
    }
  }
  return info;
}

ir::ScheduleBlockNode* DynamicShapeGroupScheduler::FindGlobalMasterNode() {
  ir::ScheduleBlockNode* master = nullptr;
  // 1. reduce
  auto FindReduce = [&](ir::ScheduleBlockNode* node) {
    if (analyzer::IsReductionSBlock(node->Block())) {
      master = node;
    }
  };
  schedule_block_graph_->NodesWalk(FindReduce);
  if (master != nullptr) {
    VLOG(6) << "Find the global master node: " << master->id();
    return master;
  }
  // 2. broadcast
  auto FindBroadcast = [&](ir::ScheduleBlockNode* node) {
    if (analyzer::IsBroadcastSBlock(node->Block())) {
      master = node;
    }
  };
  schedule_block_graph_->NodesWalk(FindBroadcast);
  if (master != nullptr) {
    VLOG(6) << "Find the global master node: " << master->id();
    return master;
  }
  // 3. end point
  master = schedule_block_graph_->EndPoints().back();
  VLOG(6) << "Find the global master node: " << master->id();
  return master;
}

} // namespace ir
} // namespace cinn
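To make the new control flow concrete, below is a minimal, self-contained sketch of the Init/Apply pattern that ApplyTactics() now follows: every tactic is re-initialized with the shared ScheduleContext before being applied to each schedule block visited by the DFS-topological walk. Context, Tactic, and LogTactic are illustrative stand-ins here, not CINN types:

#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Context {
  std::string master_block;  // stand-in for the state in schedule_context_
};

// Mirrors the shape of the ScheduleTactic interface: Init, then Apply.
struct Tactic {
  virtual ~Tactic() = default;
  virtual void Init(Context* ctx) = 0;
  virtual void Apply(const std::string& block_id) = 0;
  virtual std::string Name() const = 0;
};

struct LogTactic final : Tactic {
  Context* ctx_ = nullptr;
  void Init(Context* ctx) override { ctx_ = ctx; }
  void Apply(const std::string& block_id) override {
    std::cout << Name() << " on " << block_id
              << " (master: " << ctx_->master_block << ")\n";
  }
  std::string Name() const override { return "LogTactic"; }
};

int main() {
  Context ctx{"reduce_block"};
  std::vector<std::unique_ptr<Tactic>> tactics;
  tactics.emplace_back(new LogTactic());
  // Stand-in for schedule_block_graph_->DFSTopoWalk(ApplyTacticFunc).
  const std::vector<std::string> blocks_in_topo_order = {"block_a", "block_b"};
  for (const auto& tactic : tactics) {
    for (const std::string& block_id : blocks_in_topo_order) {
      tactic->Init(&ctx);       // re-initialize with the shared context
      tactic->Apply(block_id);  // then transform this block
    }
  }
  return 0;
}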
5 changes: 5 additions & 0 deletions paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h
@@ -42,10 +42,15 @@ class DynamicShapeGroupScheduler : public GroupScheduler {

   void ApplyTactics();

+  ir::ScheduleBlockNode* FindGlobalMasterNode();
+
+  IterativeSpaceInfo ConstructIterSpaceInfo(ScheduleBlockNode* node);
+
  private:
   std::vector<std::pair<SymbolicPredicate, std::unique_ptr<ir::IRSchedule>>>
       ir_schs_;
   std::vector<std::unique_ptr<ScheduleTactic>> tactics_;
+  ScheduleContext schedule_context_;
 };

} // namespace ir
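The new schedule_context_ member is the single channel through which tactics now receive group-level information. Its field list can be read off the assignments in Init() above; the following is a sketch under that assumption only (the real declaration lives in schedule_tactic.h and may carry more state; the pointer members merely stand in for types not declared here):

#include <string>
#include <unordered_set>

struct ScheduleBlockNode;   // stand-in forward declarations
struct IterativeSpaceInfo;
struct Target;

// Field list inferred from DynamicShapeGroupScheduler::Init().
struct ScheduleContextSketch {
  std::unordered_set<std::string> output_names;   // group output tensor names
  ScheduleBlockNode* global_master = nullptr;     // block others align to
  IterativeSpaceInfo* iter_space_info = nullptr;  // sp/rb layout of the master
  Target* target = nullptr;                       // codegen target
};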
3 changes: 2 additions & 1 deletion paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt
@@ -1,4 +1,5 @@
 core_gather_headers()

-gather_srcs(cinnapi_src SRCS arrange_storage_tactic.cc)
+gather_srcs(cinnapi_src SRCS align_iter_space_tactic.cc)
 gather_srcs(cinnapi_src SRCS compute_inline_tactic.cc)
+gather_srcs(cinnapi_src SRCS arrange_storage_tactic.cc)
87 changes: 87 additions & 0 deletions paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.cc
@@ -0,0 +1,87 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h"
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/integer_set.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_copy.h"

namespace cinn {
namespace ir {

void AlignIterSpaceTactic::Init(ScheduleContext* context) {
  context_ = context;
}

void AlignIterSpaceTactic::Apply(ir::IRSchedule* sch,
                                 const std::string& block_id) {
  ir::Expr block = sch->GetBlock(block_id);
  // Reduction blocks define the target iter space (they are chosen as the
  // global master), so they are left untouched.
  if (analyzer::IsReductionSBlock(block)) {
    return;
  }

  std::vector<ir::Expr> loops = sch->GetLoops(block_id);
  ir::Expr src_fused_loop = sch->Fuse(loops);
  ir::Expr src_total_extent = src_fused_loop.As<ir::For>()->extent;

  // Total extent of the target iter space: the product of all spatial (sp)
  // and reduce (rb) axis extents of the master block.
  ir::Expr target_sp_extent{1};
  for (const auto& iter : context_->iter_space_info.sp_space) {
    target_sp_extent = target_sp_extent * std::get<0>(iter);
  }
  ir::Expr target_total_extent = ir_utils::IRCopy(target_sp_extent);
  for (const auto& iter : context_->iter_space_info.rb_space) {
    target_total_extent = target_total_extent * std::get<0>(iter);
  }

  common::cas_intervals_t var_intervals;
  common::SymbolicExprAnalyzer symbolic_expr_analyzer(var_intervals);
  std::optional<bool> total_extent_eq =
      symbolic_expr_analyzer.ProveEQ(src_total_extent, target_total_extent);
  // Reordering is only needed if rb_last_order is not the identity
  // permutation.
  bool need_reorder = false;
  for (int i = 0; i < context_->iter_space_info.rb_last_order.size(); ++i) {
    if (context_->iter_space_info.rb_last_order[i] != i) {
      need_reorder = true;
      break;
    }
  }

  // Only re-split when the extents are provably equal; otherwise the block
  // stays as a single fused loop.
  if (total_extent_eq.has_value() && total_extent_eq.value()) {
    sch->Split(src_fused_loop,
               context_->iter_space_info.memory_consistent_order_space);
    loops = sch->GetLoops(block_id);
    if (need_reorder) {
      sch->Reorder(block_id, context_->iter_space_info.rb_last_order);
    }
    // Re-fuse the trailing reduce axes into one loop (when there are at
    // least two) ...
    if (context_->iter_space_info.sp_space.size() < loops.size() - 1) {
      loops = sch->GetLoops(block_id);
      std::vector<ir::Expr> rb_loops(
          loops.begin() + context_->iter_space_info.sp_space.size(),
          loops.end());
      sch->Fuse(rb_loops);
    }
    // ... and the leading spatial axes into another.
    if (context_->iter_space_info.sp_space.size() > 1) {
      loops = sch->GetLoops(block_id);
      std::vector<ir::Expr> sp_loops(
          loops.begin(),
          loops.begin() + context_->iter_space_info.sp_space.size());
      sch->Fuse(sp_loops);
    }
  }
}

} // namespace ir
} // namespace cinn
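In summary, for every non-reduction block the tactic fuses all loops, asks SymbolicExprAnalyzer to prove that the fused extent equals the master's total extent, and on success re-splits into the master's memory-consistent order, moves reduce-derived axes last, and fuses each group back into at most one spatial and one reduce loop. An illustration with made-up extents (S symbolic, serial axes throughout):

// Master iter space: sp_space = [S, 64], rb_space = [128], hence
// memory_consistent_order_space = [S, 64, 128], rb_last_order = [0, 1, 2].
//
// Consumer block before the tactic:    After Fuse -> Split -> Fuse:
//   for (i, 0, S)                        for (sp, 0, S * 64)  // fused spatial axes
//     for (j, 0, 64)                       for (rb, 0, 128)   // single reduce-side axis, kept as-is
//       for (k, 0, 128)                      body(...)
//         body(i, j, k)
//
// If ProveEQ cannot decide the extent equality (possible under dynamic
// shapes), the block is simply left as one fused serial loop.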
37 changes: 37 additions & 0 deletions paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h
@@ -0,0 +1,37 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <unordered_set>
#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h"

namespace cinn {
namespace ir {

class AlignIterSpaceTactic final : public ScheduleTactic {
 public:
  void Init(ScheduleContext* context) override;

  void Apply(ir::IRSchedule* sch, const std::string& block_id) override;

  std::string TacticName() const override { return "AlignIterSpaceTactic"; }

 private:
  ScheduleContext* context_;
};

} // namespace ir
} // namespace cinn
paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.cc
@@ -24,7 +24,7 @@
 namespace cinn {
 namespace ir {

-// [block_name, [var_name, for_node]]
+// [block_name, [var, for_node]]
 using VarToForMap =
     std::unordered_map<std::string, std::unordered_map<ir::Var, ir::Expr>>;
 using IntSet = common::SingleIntervalIntSet;
@@ -337,9 +337,9 @@ std::optional<CudaAxisType> AnalyzeCrossType(const VarToForMap& var2for_map,
   return std::nullopt;
 }

-ArrangeStorageTactic::ArrangeStorageTactic(
-    const std::unordered_set<std::string>& output_names)
-    : output_names_(output_names) {}
+void ArrangeStorageTactic::Init(ScheduleContext* context) {
+  output_names_ = context->output_names;
+}

 void ArrangeStorageTactic::Apply(ir::IRSchedule* sch,
                                  const std::string& block_id) {
5 changes: 3 additions & 2 deletions paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.h
@@ -23,11 +23,12 @@ namespace ir {

 class ArrangeStorageTactic final : public ScheduleTactic {
  public:
-  explicit ArrangeStorageTactic(
-      const std::unordered_set<std::string>& output_names);
+  void Init(ScheduleContext* context) override;

   void Apply(ir::IRSchedule* sch, const std::string& block_id) override;

+  std::string TacticName() const override { return "ArrangeStorageTactic"; }
+
  private:
   std::unordered_set<std::string> output_names_;
 };
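Taken together with the scheduler changes above, this refactor moves tactic configuration from constructor arguments into the shared context. The lines below are taken directly from this diff:

// Before this PR: per-tactic constructor injection.
tactics_.emplace_back(new ArrangeStorageTactic(output_names));

// After: uniform no-argument construction ...
tactics_.emplace_back(new ArrangeStorageTactic());
// ... with state handed over per node inside ApplyTactics():
tactic->Init(&schedule_context_);
tactic->Apply(ir_sch_, node->id());

The practical payoff is that a new tactic such as AlignIterSpaceTactic can be appended to tactics_ without threading additional parameters through every constructor.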