[CINN] New tiling method optimized for warp-level continuous read (#64240)

* add new tile apply function

* replace index of local buffer
lshpku authored Jun 3, 2024
1 parent 252b746 commit 72e8d4e
Showing 2 changed files with 151 additions and 0 deletions.
120 changes: 120 additions & 0 deletions paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc
@@ -47,11 +47,26 @@ bool IsWarpReduce(const ScheduleConfig& config) {
  return std::visit(MatchWarpReduce, config.tile_config.reduce_method);
}

bool UseReduceTile(const ScheduleConfig& config) {
  const auto& raw_reduce_axis = config.base_info->raw_reduce_axis;
  const auto raw_data_rank = config.base_info->raw_data_rank;
  if (raw_reduce_axis.empty()) {
    return false;
  }
  for (size_t i = 1; i < raw_reduce_axis.size(); i++) {
    if (raw_reduce_axis[i] != raw_reduce_axis[i - 1] + 1) {
      return false;
    }
  }
  return raw_reduce_axis.back() + 1 == raw_data_rank;
}
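
UseReduceTile gates the new tactic: it returns true only when the reduce axes are consecutive and end at the innermost dimension, i.e. a trailing-reduce layout such as [S, S, R, R]. A minimal standalone sketch of the same check (hypothetical helper and file name, not part of the diff):

// use_reduce_tile_sketch.cc -- illustrative only, mirrors the predicate above.
#include <cassert>
#include <cstdint>
#include <vector>

// True iff the reduce axes are contiguous and finish at the last dimension.
bool UseReduceTileSketch(const std::vector<int64_t>& reduce_axis,
                         int64_t data_rank) {
  if (reduce_axis.empty()) return false;
  for (size_t i = 1; i < reduce_axis.size(); i++) {
    if (reduce_axis[i] != reduce_axis[i - 1] + 1) return false;
  }
  return reduce_axis.back() + 1 == data_rank;
}

int main() {
  assert(UseReduceTileSketch({2, 3}, 4));   // [S, S, R, R]: trailing reduce
  assert(!UseReduceTileSketch({1, 3}, 4));  // [S, R, S, R]: not contiguous
  assert(!UseReduceTileSketch({1, 2}, 4));  // [S, R, R, S]: not trailing
  return 0;
}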

class TileFirstGeneralTactic final : public ScheduleTactic {
 public:
  void Init(ScheduleContext* context) override;

  void Apply(ir::IRSchedule* sch, const std::string& block_id) override;
  void ApplyReduceTile(ir::IRSchedule* sch, const std::string& block_id);

  std::string TacticName() const override { return "TileFirstGeneralTactic"; }

@@ -98,6 +113,11 @@ void TileFirstGeneralTactic::Init(ScheduleContext* context) {

void TileFirstGeneralTactic::Apply(ir::IRSchedule* sch,
                                   const std::string& block_id) {
  if (UseReduceTile(context_->config)) {
    VLOG(4) << "Using ApplyReduceTile";
    ApplyReduceTile(sch, block_id);
    return;
  }
  if (ir::IsReduceInitTensorName(block_id)) return;
  MergeReduceAxis(sch, block_id);
  VLOG(6) << "After MergeReduceAxis on block: [" << block_id
@@ -136,6 +156,106 @@ void TileFirstGeneralTactic::Apply(ir::IRSchedule* sch,
  SetReduceType(sch, block_id);
}

void TileFirstGeneralTactic::ApplyReduceTile(ir::IRSchedule* sch,
                                             const std::string& block_id) {
  if (ir::IsReduceInitTensorName(block_id)) return;

  const auto sp_thread = context_->config.tile_config.warp_num * 32 /
                         context_->config.tile_config.tree_reduce_num;
  const auto sp_loop = context_->config.tile_config.spatial_inner_num;
  const auto rd_thread = context_->config.tile_config.tree_reduce_num;
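  // Worked example (hypothetical config, not from the diff): warp_num = 8
  // gives 8 * 32 = 256 threads per block; with tree_reduce_num = 64 this
  // yields sp_thread = 256 / 64 = 4, i.e. a block of 4 spatial lanes
  // (threadIdx.y) by 64 reduce lanes (threadIdx.x).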
VLOG(4) << "ApplyReduceTile sp_thread=" << sp_thread;
VLOG(4) << "ApplyReduceTile sp_loop=" << sp_loop;
VLOG(4) << "ApplyReduceTile rd_thread=" << rd_thread;
VLOG(4) << "ApplyReduceTile vec_flatten_axis: "
<< utils::Join(vec_flatten_axis_, ", ");
VLOG(4) << "ApplyReduceTile vec_reduce_axis: "
<< utils::Join(vec_reduce_axis_, ", ");

// Merge reduce axes
MergeReduceAxis(sch, block_id);
VLOG(4) << "After MergeReduceAxis on block: [" << block_id
<< "], loop nest:\n"
<< sch->GetModule().GetExprs().front();

// Merge spatial axes
MergeFlattenAxis(sch, block_id);
VLOG(4) << "After MergeFlattenAxis on block: [" << block_id
<< "], loop nest:\n"
<< sch->GetModule().GetExprs().front();

  // Split spatial axes -> [sp_block, sp_loop, sp_thread]
  int current_reduce_axis = 0;
  if (vec_flatten_axis_.size() > 0) {
    auto loops = sch->GetLoops(block_id);
    if (sp_loop > 1 && sp_thread > 1) {
      sch->Split(loops[0], {-1, sp_loop, sp_thread});
      current_reduce_axis = 3;
    } else if (sp_loop > 1 || sp_thread > 1) {
      sch->Split(loops[0], {-1, sp_loop > 1 ? sp_loop : sp_thread});
      current_reduce_axis = 2;
    } else {
      current_reduce_axis = 1;
    }
  }
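  // After the split above the spatial loops are [sp_block, sp_loop, sp_thread]
  // (size-1 levels are omitted), so current_reduce_axis now points at the
  // first reduce loop in the nest.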
VLOG(4) << "After SplitSptial on block: [" << block_id << "], loop nest:\n"
<< sch->GetModule().GetExprs().front();

// Split reduce axes -> [rd_loop, rd_thread]
if (vec_reduce_axis_.size() > 0) {
auto loops = sch->GetLoops(block_id);
auto reduce_loop = loops[current_reduce_axis].As<ir::For>();
sch->Split(loops[current_reduce_axis], {-1, rd_thread});
VLOG(4) << "Before ReorderReduction on block: [" << block_id
<< "], loop nest:\n"
<< sch->GetModule().GetExprs().front();

// TODO(lshpku): the Reorder is unneeded if the later FactorizeReduction
// supports rf_axis=1.
loops = sch->GetLoops(block_id);
sch->Reorder({loops[current_reduce_axis + 1], loops[current_reduce_axis]});
VLOG(4) << "Before FactorizeReduction on block: [" << block_id
<< "], loop nest:\n"
<< sch->GetModule().GetExprs().front();

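    // FactorizeReduction materializes the per-thread partial results in a new
    // "<block_id>_rf" block, which is why DoBind below is also applied to the
    // "_rf" block when it exists.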
    if (IsReduceBlock(context_->config, block_id)) {
      loops = sch->GetLoops(block_id);
      sch->FactorizeReduction(loops[current_reduce_axis],
                              /* rf_axis = */ 0,
                              /* with_write_back_block_init = */ false);
    }
  }
  VLOG(4) << "After SplitReduce on block: [" << block_id << "], loop nest:\n"
          << sch->GetModule().GetExprs().front();

  // Bind CUDA info
  const auto DoBind = [&](const std::vector<ir::Expr>& loops) {
    std::string sp_axis_type = "threadIdx.y";
    std::string rd_axis_type = "threadIdx.x";
    sch->Bind(loops[0], "blockIdx.x");
    if (!vec_flatten_axis_.empty() && sp_thread > 1) {
      if (vec_reduce_axis_.empty()) {
        sch->Bind(loops[current_reduce_axis - 1], rd_axis_type);
      } else {
        sch->Bind(loops[current_reduce_axis - 1], sp_axis_type);
      }
    }
    if (!vec_reduce_axis_.empty() && current_reduce_axis > 0) {
      sch->Bind(loops[current_reduce_axis], rd_axis_type);
    }
  };
  DoBind(sch->GetLoops(block_id));
  if (IsReduceBlock(context_->config, block_id) &&
      sch->HasBlock(block_id + "_rf")) {
    DoBind(sch->GetLoops(block_id + "_rf"));
  }
  VLOG(4) << "After BindCudaInfo on block: [" << block_id << "], loop nest:\n"
          << sch->GetModule().GetExprs().front();

  VariableTypeAssignment(sch, block_id);
  SetReduceType(sch, block_id);
}
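
For a trailing-reduce pattern [S, R], the tactic yields roughly the nest sketched below (illustrative only, assuming sp_loop > 1, sp_thread > 1, and a reduce block; the real output is CINN schedule IR). Binding rd_thread to threadIdx.x over the innermost axis makes adjacent warp lanes read adjacent elements, which is the warp-level continuous read the commit title refers to.

// Illustrative schedule shape produced by ApplyReduceTile (pseudo-IR):
//   blockIdx.x  in [0, S / (sp_loop * sp_thread))  // spatial blocks
//     sp_loop   in [0, sp_loop)                    // serial spatial loop
//       threadIdx.y in [0, sp_thread)              // spatial lanes
//         threadIdx.x in [0, rd_thread)            // reduce lanes, coalesced
//           rd_loop in [0, R / rd_thread)          // serial reduce loop
//             rf(...) = reduce(rf(...), x(...))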

void TileFirstGeneralTactic::MergeFlattenAxis(ir::IRSchedule* sch,
                                              const std::string& block_id) {
  if (vec_flatten_axis_.size() >= 2) {
31 changes: 31 additions & 0 deletions paddle/cinn/optim/resize_buffer.cc
@@ -249,6 +249,7 @@ class ResizeBufferFromAnalyzedRange : public ir::IRMutator<> {
    ir::Store* store = expr->As<ir::Store>();
    ir::Tensor tensor = store->tensor.as_tensor_ref();
    ResizeTensor(&tensor);
    ReplaceTensorIndices<ir::Store>(store);
    ir::IRMutator<>::Visit(op, expr);
  }

@@ -277,6 +278,7 @@ class ResizeBufferFromAnalyzedRange : public ir::IRMutator<> {
    for (int i = 0; i < cnt; i++) {
      load->indices.erase(load->indices.begin());
    }
    ReplaceTensorIndices<ir::Load>(load);
    ir::IRMutator<>::Visit(op, expr);
  }

@@ -304,6 +306,35 @@ class ResizeBufferFromAnalyzedRange : public ir::IRMutator<> {
    }
  }

  template <typename T>
  void ReplaceTensorIndices(T* op) {
    ir::Tensor tensor = op->tensor.as_tensor_ref();
    ir::Buffer buffer = tensor->buffer;
    if (!buffer.defined()) return;
    if (buffer->memory_type != ir::MemoryType::GPULocal) return;

    VLOG(4) << "replacing index of tensor: " << tensor->name;
    ir::Expr index_expr = op->index();
    std::unordered_map<std::string, ir::Expr> var_name_to_expr;
    ir::ir_utils::CollectIRNodes(index_expr, [&](const ir::Expr* x) {
      const ir::_Var_* var = x->as_var();
      if (var) {
        var_name_to_expr[var->name] = var->Copy();
      }
      return false;
    });
    if (var_name_to_expr.size() != 1) {
      return;
    }

    ir::Expr single_var = var_name_to_expr.begin()->second;
    VLOG(4) << "found single var: " << single_var;
    for (size_t i = 0; i + 1 < op->indices.size(); i++) {
      op->indices[i] = ir::Expr(0);
    }
    op->indices.back() = single_var;
  }
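
In effect, when a GPULocal (register) buffer's flattened index contains exactly one loop variable, every leading index is zeroed and the last index becomes that variable. A hedged before/after sketch with hypothetical names:

// Before: local buffer addressed through expressions derived from one var k.
//   var_0_local[k / 4][k % 4] = ...
// After ReplaceTensorIndices: leading dims zeroed, last index is k itself.
//   var_0_local[0][k] = ...
// The surrounding pass has just resized the buffer from its analyzed access
// range, which is presumably what keeps the rewritten direct index in bounds.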

 private:
  const std::unordered_map<std::string, std::vector<ir::Expr>>&
      buffer_name_to_shape_;
