Skip to content

Commit

Permalink
[CINN] Add tile tactic and bind cuda tactic (#60534)
Browse files Browse the repository at this point in the history
* [CINN] Add tile tactic

* [CINN] Add bind cuda tactic
  • Loading branch information
BiynXu authored Jan 5, 2024
1 parent c3106c4 commit bc13117
Show file tree
Hide file tree
Showing 8 changed files with 256 additions and 9 deletions.
36 changes: 28 additions & 8 deletions paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,36 @@
#include "paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h"
#include "paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h"
#include "paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.h"
#include "paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h"
#include "paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h"
#include "paddle/cinn/ir/group_schedule/tactic/tile_tactic.h"
#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
#include "paddle/cinn/ir/op/ir_operators.h"

namespace cinn {
namespace ir {

// Fills schedule_context_ (target, output names, global master node,
// iteration-space info and bucket bounds) and registers the tactics that
// ApplyTactics() later runs.
void DynamicShapeGroupScheduler::Init() {
  // Only 1 bucket for test now.
  // The target is assigned exactly once, before the remaining context fields
  // are derived.
  schedule_context_.target = target_;
  schedule_context_.output_names = OutputTensorNames();
  schedule_context_.global_master = FindGlobalMasterNode();
  schedule_context_.iter_space_info =
      ConstructIterSpaceInfo(schedule_context_.global_master);
  schedule_context_.bucket_info = {/* sp_lower_bound = */ 1024,
                                   /* sp_upper_bound = */ INT_MAX,
                                   /* rb_lower_bound = */ 64,
                                   /* rb_upper_bound = */ INT_MAX};

  // Registration order is kept as-is: TileTactic comes before BindCudaTactic,
  // whose Apply() consumes the axis types that TileTactic::Init() writes into
  // the shared iteration-space info.
  tactics_.emplace_back(new AlignIterSpaceTactic());
  tactics_.emplace_back(new TileTactic());
  tactics_.emplace_back(new ComputeInlineTactic());
  tactics_.emplace_back(new BindCudaTactic());
  tactics_.emplace_back(new ArrangeStorageTactic());
}

void DynamicShapeGroupScheduler::Schedule() {
// Fake schedule for test
ApplyTactics();
std::vector<Expr> all_blocks = ir_sch_->GetAllBlocks();
auto block0_loops = ir_sch_->GetLoops(all_blocks[0]);
auto splited_loops1 = ir_sch_->Split(block0_loops[0], {1024, -1});
ir_sch_->Bind(splited_loops1[0], "threadIdx.x");

// Fake bucket for test
ir::Expr predicate1 = ir::LE::Make(Expr(1023), Expr(1024));
std::unique_ptr<ir::IRSchedule> new_ir_sch1 =
std::make_unique<ir::IRSchedule>(*ir_sch_);
Expand All @@ -55,12 +60,12 @@ void DynamicShapeGroupScheduler::ApplyTactics() {
VLOG(6) << "before applying [" << tactic->TacticName()
<< "] on ScheduleBlockNode [" << node->id() << "] func body:\n"
<< ir_sch_->GetModule().GetExprs().front();
tactic->Init(&schedule_context_);
tactic->Apply(ir_sch_, node->id());
VLOG(6) << "after applying [" << tactic->TacticName()
<< "] on ScheduleBlockNode [" << node->id() << "] func body:\n"
<< ir_sch_->GetModule().GetExprs().front();
};
tactic->Init(&schedule_context_);
schedule_block_graph_->DFSTopoWalk(ApplyTacticFunc);
schedule_block_graph_->Update(*ir_sch_);
VLOG(5) << "[End " << tactic->TacticName()
Expand Down Expand Up @@ -96,6 +101,7 @@ IterativeSpaceInfo DynamicShapeGroupScheduler::ConstructIterSpaceInfo(
std::unordered_map<ir::Var, ir::Expr> iter_var2value =
analyzer::GetIterVarToValueOfSBlock(block);

// init iter info
if (!reduce_iter_vars.empty()) {
std::set<ir::Expr> reduce_loads = ir::ir_utils::CollectIRNodesWithoutTensor(
block,
Expand Down Expand Up @@ -161,6 +167,20 @@ IterativeSpaceInfo DynamicShapeGroupScheduler::ConstructIterSpaceInfo(
info.rb_last_order.push_back(i);
}
}
// init total extents
ir::Expr sp_extent = ir::Expr(1);
ir::Expr rb_extent = ir::Expr(1);
for (const auto& axis : info.sp_space) {
const ir::Expr& extent = std::get<0>(axis);
sp_extent = sp_extent * extent;
}
for (const auto& axis : info.rb_space) {
const ir::Expr& extent = std::get<0>(axis);
rb_extent = rb_extent * extent;
}
info.total_sp_extent = sp_extent;
info.total_rb_extent = rb_extent;

return info;
}

Expand Down
2 changes: 2 additions & 0 deletions paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Project helper; presumably collects this directory's headers for the core
# target -- TODO confirm against the macro definition.
core_gather_headers()

# One gather_srcs() call per tactic implementation file; each call passes the
# file to the cinnapi_src group. A new tactic .cc needs its own line here.
gather_srcs(cinnapi_src SRCS align_iter_space_tactic.cc)
gather_srcs(cinnapi_src SRCS tile_tactic.cc)
gather_srcs(cinnapi_src SRCS compute_inline_tactic.cc)
gather_srcs(cinnapi_src SRCS bind_cuda_tactic.cc)
gather_srcs(cinnapi_src SRCS arrange_storage_tactic.cc)
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#pragma once

#include <string>
#include <unordered_set>
#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h"

namespace cinn {
Expand Down
58 changes: 58 additions & 0 deletions paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h"
#include <unordered_map>
#include "paddle/cinn/ir/ir.h"

namespace cinn {
namespace ir {

// Stores the shared schedule context; Apply() later reads the sp/rb
// iteration spaces from it. The context is not owned by the tactic.
void BindCudaTactic::Init(ScheduleContext* context) {
  this->context_ = context;
}

// Lookup table from an axis type to the thread-axis string handed to
// IRSchedule::Bind() by BindCudaTactic::Apply(). Only the six CUDA block/thread
// axis types are listed; an axis type that is absent from this table is left
// unbound by Apply().
const std::unordered_map<IterativeSpaceInfo::AxisType, std::string>
axis_type2bind_info = {
// Block-level axes.
{IterativeSpaceInfo::AxisType::kCudaBlockX, "blockIdx.x"},
{IterativeSpaceInfo::AxisType::kCudaBlockY, "blockIdx.y"},
{IterativeSpaceInfo::AxisType::kCudaBlockZ, "blockIdx.z"},
// Thread-level axes.
{IterativeSpaceInfo::AxisType::kCudaThreadX, "threadIdx.x"},
{IterativeSpaceInfo::AxisType::kCudaThreadY, "threadIdx.y"},
{IterativeSpaceInfo::AxisType::kCudaThreadZ, "threadIdx.z"},
};

// Binds the loops of `block_id` to CUDA axes.
//
// The block's loops are matched positionally against the axes recorded in the
// context: first every sp_space axis, then every rb_space axis, stopping as
// soon as either the axes or the loops run out. A loop is bound only when its
// axis type has an entry in axis_type2bind_info; other loops are skipped but
// still consume a position.
void BindCudaTactic::Apply(ir::IRSchedule* sch, const std::string& block_id) {
  std::vector<ir::Expr> block_loops = sch->GetLoops(block_id);

  // Index of the next loop to pair with an axis. It is shared by both passes
  // so the rb axes continue where the sp axes stopped.
  size_t next_loop = 0;

  const auto bind_axes = [&](const auto& space) {
    for (size_t axis_idx = 0;
         axis_idx < space.size() && next_loop < block_loops.size();
         ++axis_idx, ++next_loop) {
      const auto bind_it =
          axis_type2bind_info.find(std::get<1>(space[axis_idx]));
      if (bind_it == axis_type2bind_info.end()) {
        continue;  // axis type has no CUDA binding
      }
      sch->Bind(block_loops[next_loop], bind_it->second);
    }
  };

  bind_axes(context_->iter_space_info.sp_space);
  bind_axes(context_->iter_space_info.rb_space);
}

} // namespace ir
} // namespace cinn
36 changes: 36 additions & 0 deletions paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h"

namespace cinn {
namespace ir {

// Schedule tactic that binds a block's loops to CUDA block/thread axes
// according to the axis types stored in the ScheduleContext's iteration-space
// info. Init() must be called before Apply().
class BindCudaTactic final : public ScheduleTactic {
 public:
  // Stores `context` (not owned) for use by Apply().
  void Init(ScheduleContext* context) override;

  // Binds the loops of the schedule block `block_id` in `sch`.
  void Apply(ir::IRSchedule* sch, const std::string& block_id) override;

  std::string TacticName() const override { return "BindCudaTactic"; }

 private:
  // Non-owning; null until Init() is called, so a missing Init() shows up as
  // a null pointer rather than an indeterminate one.
  ScheduleContext* context_ = nullptr;
};

} // namespace ir
} // namespace cinn
12 changes: 12 additions & 0 deletions paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ struct IterativeSpaceInfo {
std::vector<std::tuple<ir::Expr, AxisType>> sp_space;
// reduce or broadcast iterative space
std::vector<std::tuple<ir::Expr, AxisType>> rb_space;
// total sp extent
ir::Expr total_sp_extent;
// total rb extent
ir::Expr total_rb_extent;
// original loop order with same iteration order as the memory order
std::vector<ir::Expr> memory_consistent_order_space;
// index that transform from memory consistent order to rb last order
Expand All @@ -45,11 +49,19 @@ struct IterativeSpaceInfo {
std::vector<int> rb_last_order;
};

// Bounds of one schedule bucket for the sp and rb iteration spaces.
// The defaults describe an unbounded bucket: [0, INT_MAX] for both spaces.
struct BucketInfo {
  int sp_lower_bound = 0;
  // INT_MAX, not UINT_MAX: the fields are signed, and UINT_MAX does not fit
  // in an int (it typically wraps to -1, i.e. below the lower bound).
  int sp_upper_bound = INT_MAX;
  int rb_lower_bound = 0;
  int rb_upper_bound = INT_MAX;
};

// State shared between the group scheduler and its tactics. Each tactic
// receives a pointer to it through ScheduleTactic::Init().
struct ScheduleContext {
  // Names of the output tensors.
  std::unordered_set<std::string> output_names;
  // Global master schedule-block node; non-owning. Null until the scheduler
  // sets it, so it never holds an indeterminate value.
  ScheduleBlockNode* global_master = nullptr;
  // sp / rb iteration spaces; tactics both read and rewrite this.
  IterativeSpaceInfo iter_space_info;
  // Compilation target.
  Target target;
  // Bounds of the bucket currently being scheduled.
  BucketInfo bucket_info;
};

class ScheduleTactic {
Expand Down
84 changes: 84 additions & 0 deletions paddle/cinn/ir/group_schedule/tactic/tile_tactic.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/ir/group_schedule/tactic/tile_tactic.h"
#include "paddle/cinn/ir/ir.h"

namespace cinn {
namespace ir {

// Stores the context and replaces its sp/rb iteration spaces with a fixed
// (test-only) tiling plan derived from the bucket lower bounds.
//
// If the original sp space was non-empty it becomes three axes:
//   [sp_lower_bound / f -> blockIdx.x, f -> threadIdx.y or threadIdx.x,
//    -1 -> serial]
// where f is the largest proper divisor of sp_lower_bound. threadIdx.y is
// used when an rb space exists, because threadIdx.x is then given to it.
// If the original rb space was non-empty it becomes two axes:
//   [rb_lower_bound -> threadIdx.x, -1 -> serial].
// A space that was empty stays empty.
void TileTactic::Init(ScheduleContext* context) {
  context_ = context;
  // fake strategy
  // Returns the largest divisor of `num` that is smaller than `num`. Falls
  // back to 1 when no such divisor exists (num <= 1), so the lambda never
  // flows off its end without a value and the caller never divides by an
  // indeterminate factor.
  auto GetFirstFactor = [](int num) {
    for (int i = num - 1; i >= 1; --i) {
      if (num % i == 0) {
        return i;
      }
    }
    return 1;
  };

  // Record which spaces existed before wiping them; only those are rebuilt.
  bool has_rb_iter = !context_->iter_space_info.rb_space.empty();
  bool has_sp_iter = !context_->iter_space_info.sp_space.empty();
  context_->iter_space_info.rb_space.clear();
  context_->iter_space_info.sp_space.clear();

  if (has_sp_iter) {
    int sp_factor = GetFirstFactor(context_->bucket_info.sp_lower_bound);
    context_->iter_space_info.sp_space.emplace_back(
        ir::Expr(context_->bucket_info.sp_lower_bound / sp_factor),
        IterativeSpaceInfo::AxisType::kCudaBlockX);
    context_->iter_space_info.sp_space.emplace_back(
        ir::Expr(sp_factor),
        has_rb_iter ? IterativeSpaceInfo::AxisType::kCudaThreadY
                    : IterativeSpaceInfo::AxisType::kCudaThreadX);
    // The -1 extent is handed to IRSchedule::Split() as a factor by Apply().
    context_->iter_space_info.sp_space.emplace_back(
        ir::Expr(-1), IterativeSpaceInfo::AxisType::kSerial);
  }

  if (has_rb_iter) {
    context_->iter_space_info.rb_space.emplace_back(
        ir::Expr(context_->bucket_info.rb_lower_bound),
        IterativeSpaceInfo::AxisType::kCudaThreadX);
    context_->iter_space_info.rb_space.emplace_back(
        ir::Expr(-1), IterativeSpaceInfo::AxisType::kSerial);
  }
}

// Splits the loops of `block_id` using the extents stored in the context.
//
// The block must have exactly one loop (sp only) or two loops (sp, then rb).
// With two loops the rb loop (index 1) is split first by the rb extents and
// the loop list is re-read; the sp loop (index 0) is then split by the sp
// extents.
void TileTactic::Apply(ir::IRSchedule* sch, const std::string& block_id) {
  // Gathers the extent (tuple element 0) of every axis in `space`, in order.
  const auto collect_extents = [](const auto& space) {
    std::vector<ir::Expr> extents;
    for (const auto& axis : space) {
      extents.push_back(std::get<0>(axis));
    }
    return extents;
  };

  std::vector<ir::Expr> loops = sch->GetLoops(block_id);
  CHECK(loops.size() == 1 || loops.size() == 2)
      << "All loops must be unified as sp_loop or rb_loop.";

  const bool has_rb_loop = (loops.size() == 2);
  if (has_rb_loop) {
    const std::vector<ir::Expr> rb_factors =
        collect_extents(context_->iter_space_info.rb_space);
    sch->Split(loops[1], rb_factors);
    // Re-read the loops after the rb split before touching the sp loop.
    loops = sch->GetLoops(block_id);
    VLOG(6) << "after split rb loop of " << block_id << ": "
            << sch->GetModule().GetExprs()[0];
  }

  const std::vector<ir::Expr> sp_factors =
      collect_extents(context_->iter_space_info.sp_space);
  sch->Split(loops[0], sp_factors);
  VLOG(6) << "after split sp loop of " << block_id << ": "
          << sch->GetModule().GetExprs()[0];
}

} // namespace ir
} // namespace cinn
36 changes: 36 additions & 0 deletions paddle/cinn/ir/group_schedule/tactic/tile_tactic.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h"

namespace cinn {
namespace ir {

// Schedule tactic that splits a block's sp/rb loops into tiles. Init()
// rewrites the context's iteration spaces into the tiling plan; Apply()
// performs the corresponding loop splits. Init() must be called before
// Apply().
class TileTactic final : public ScheduleTactic {
 public:
  // Stores `context` (not owned) and writes the tiling plan into it.
  void Init(ScheduleContext* context) override;

  // Splits the loops of the schedule block `block_id` in `sch`.
  void Apply(ir::IRSchedule* sch, const std::string& block_id) override;

  std::string TacticName() const override { return "TileTactic"; }

 private:
  // Non-owning; null until Init() is called, so a missing Init() shows up as
  // a null pointer rather than an indeterminate one.
  ScheduleContext* context_ = nullptr;
};

} // namespace ir
} // namespace cinn

0 comments on commit bc13117

Please sign in to comment.