Skip to content

Commit

Permalink
[CINN]Add bucket context (#60549)
Browse files Browse the repository at this point in the history
* [CINN] Add tile tactic

* [CINN] Add bind cuda tactic

* [CINN] Add bucket contexts

* fix group output args bug
  • Loading branch information
BiynXu authored Jan 8, 2024
1 parent 385ec43 commit e2b4247
Show file tree
Hide file tree
Showing 14 changed files with 213 additions and 128 deletions.
10 changes: 4 additions & 6 deletions paddle/cinn/hlir/framework/op_lowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,10 @@ class OpLowerer {
group, apply_op_schedule, apply_group_schedule, apply_pass);
}

std::vector<
std::pair<ir::SymbolicPredicate, pir::OpLowererImpl::WrapLoweredFunc>>
BucketLower(const T& group,
bool apply_op_schedule = false,
bool apply_group_schedule = true,
bool apply_pass = true) {
BucketLoweredFuncsWrapper BucketLower(const T& group,
bool apply_op_schedule = false,
bool apply_group_schedule = true,
bool apply_pass = true) {
return impl_->BucketLower(
group, apply_op_schedule, apply_group_schedule, apply_pass);
}
Expand Down
9 changes: 4 additions & 5 deletions paddle/cinn/hlir/framework/op_lowering_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,10 @@ class OpLowererImpl : public OpLowererImplBase<GroupPtr> {
bool apply_group_schedule = true,
bool apply_pass = true);

std::vector<std::pair<ir::SymbolicPredicate, WrapLoweredFunc>> BucketLower(
const GroupPtr& group,
bool apply_op_schedule = false,
bool apply_group_schedule = true,
bool apply_pass = true) {
BucketLoweredFuncsWrapper BucketLower(const GroupPtr& group,
bool apply_op_schedule = false,
bool apply_group_schedule = true,
bool apply_pass = true) {
CINN_NOT_IMPLEMENTED;
}

Expand Down
24 changes: 12 additions & 12 deletions paddle/cinn/hlir/framework/op_lowering_impl_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,15 @@ namespace cinn {
namespace hlir {
namespace framework {

// Aggregates the results of bucket lowering for one group: every lowered
// kernel function paired with the symbolic predicate that selects it, plus
// a single function used for shape inference.
struct BucketLoweredFuncsWrapper {
// Per-bucket (predicate, kernel) pairs. Consumers iterate these and keep
// the predicate/function lists side by side (see
// GroupCompilationContext::SetLoweredFuncs in compilation_task.cc).
std::vector<std::pair<ir::SymbolicPredicate, ir::LoweredFunc>>
predicate2funcs;
// Shape-inference function for the group; registered on the module
// builder via SetInferShapeFunc. May be a default-constructed
// ir::LoweredFunc (e.g. the custom_call early-return path) — presumably
// meaning "no infer-shape function"; confirm with downstream users.
ir::LoweredFunc infer_shape_func;
};

template <typename T>
class OpLowererImplBase {
public:
struct WrapLoweredFunc {
ir::LoweredFunc kernel_func;
ir::LoweredFunc infer_shape_func;
WrapLoweredFunc(ir::LoweredFunc kernel_func,
ir::LoweredFunc infer_shape_func = ir::LoweredFunc())
: infer_shape_func(infer_shape_func), kernel_func(kernel_func) {}
};
OpLowererImplBase() = default;
~OpLowererImplBase() = default;

Expand All @@ -45,11 +44,12 @@ class OpLowererImplBase {
bool apply_group_schedule = true,
bool apply_pass = true) = 0;

virtual std::vector<std::pair<ir::SymbolicPredicate, WrapLoweredFunc>>
BucketLower(const T& group,
bool apply_op_schedule = false,
bool apply_group_schedule = true,
bool apply_pass = true) = 0;
virtual BucketLoweredFuncsWrapper BucketLower(
const T& group,
bool apply_op_schedule = false,
bool apply_group_schedule = true,
bool apply_pass = true) = 0;

virtual void InsertNameGeneToScope(std::shared_ptr<Scope> scope) = 0;
};

Expand Down
16 changes: 7 additions & 9 deletions paddle/cinn/hlir/framework/pir/compilation_task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,14 @@ namespace hlir {
namespace framework {

void GroupCompilationContext::SetLoweredFuncs(
std::vector<std::pair<ir::SymbolicPredicate,
pir::OpLowererImpl::WrapLoweredFunc>>&& funcs) {
for (std::pair<ir::SymbolicPredicate, pir::OpLowererImpl::WrapLoweredFunc>&
predicate2func : funcs) {
predicates_.push_back(predicate2func.first);
lowered_funcs_.push_back(predicate2func.second.kernel_func);
infer_shape_lowered_funcs_.push_back(
predicate2func.second.infer_shape_func);
BucketLoweredFuncsWrapper&& funcs) {
for (std::pair<ir::SymbolicPredicate, ir::LoweredFunc>& predicate2func :
funcs.predicate2funcs) {
predicates_.push_back(std::move(predicate2func.first));
lowered_funcs_.push_back(std::move(predicate2func.second));
++func_size_;
}
infer_shape_lowered_func_ = std::move(funcs.infer_shape_func);
}

std::string GroupCompilationContext::PrintPredicate2Funcs() const {
Expand Down Expand Up @@ -77,7 +75,7 @@ void CompilationTask::CodegenAndJit() {
for (const ir::LoweredFunc& func : context_->lowered_funcs_) {
builder.AddFunction(func);
}
builder.AddInferShapeFunc(context_->infer_shape_lowered_funcs_[0]);
builder.SetInferShapeFunc(context_->infer_shape_lowered_func_);
ir::Module ir_module = builder.Build();

context_->backend_compiler_ = backends::Compiler::Create(context_->target_);
Expand Down
6 changes: 2 additions & 4 deletions paddle/cinn/hlir/framework/pir/compilation_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ class GroupCompilationContext {
std::shared_ptr<Scope> scope)
: target_(target), group_(group), scope_(scope) {}

void SetLoweredFuncs(
std::vector<std::pair<ir::SymbolicPredicate,
pir::OpLowererImpl::WrapLoweredFunc>>&& funcs);
void SetLoweredFuncs(BucketLoweredFuncsWrapper&& funcs);
std::string PrintPredicate2Funcs() const;
void* FuncPtr();
std::shared_ptr<backends::Compiler> BackendCompiler();
Expand All @@ -48,7 +46,7 @@ class GroupCompilationContext {
size_t func_size_ = 0;
std::vector<ir::SymbolicPredicate> predicates_;
std::vector<ir::LoweredFunc> lowered_funcs_;
std::vector<ir::LoweredFunc> infer_shape_lowered_funcs_;
ir::LoweredFunc infer_shape_lowered_func_;
std::string host_func_name_;
std::string host_code_;
std::vector<std::string> device_code_;
Expand Down
85 changes: 46 additions & 39 deletions paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,14 @@ std::vector<ir::LoweredFunc> OpLowererImpl::Lower(const GroupPtr& group,
LOG(FATAL) << "Group Pattern Kind Is Unknown!";
}
}

std::vector<std::pair<ir::SymbolicPredicate, OpLowererImpl::WrapLoweredFunc>>
OpLowererImpl::BucketLower(const GroupPtr& group,
bool apply_op_schedule,
bool apply_group_schedule,
bool apply_pass) {
BucketLoweredFuncsWrapper OpLowererImpl::BucketLower(const GroupPtr& group,
bool apply_op_schedule,
bool apply_group_schedule,
bool apply_pass) {
// 1.Do compute, lower and schedule for each op.
auto& ops = group->ops;
if (ops.size() == 1 && ops[0]->name() == "custom_call") {
return {{ir::Expr(1),
pir::OpLowererImpl::WrapLoweredFunc(LowerCustomCall(group)[0])}};
return {{{ir::Expr(1), LowerCustomCall(group)[0]}}, ir::LoweredFunc()};
}
std::vector<ir::Tensor> group_func_arg_tensors;
std::unordered_map<::pir::Value, ir::Tensor> tensor_map;
Expand Down Expand Up @@ -152,24 +149,29 @@ OpLowererImpl::BucketLower(const GroupPtr& group,
// 3.Do post-processing,
// including preparing function args and temporary variables,
// applying low-level optimization passes, etc.
std::vector<std::pair<ir::Expr, WrapLoweredFunc>> cond2funcs;
std::vector<ir::Expr> scheduled_func_bodies;
for (std::pair<ir::SymbolicPredicate, ir::Expr>& cond2body :
cond2func_bodies) {
std::vector<ir::Tensor> group_func_arg_tensors_copy =
group_func_arg_tensors;
std::vector<ir::Argument> group_func_args;
std::vector<ir::LoweredFunc> funcs =
PostProcess(group,
tensor_map,
apply_op_schedule,
cond2body.second,
&group_func_arg_tensors_copy,
&group_func_args);
ir::LoweredFunc infer_shape_func = GenerateInferShapeFunc(
group, group_func_arg_tensors_copy, group_func_args);
cond2funcs.push_back({cond2body.first, {funcs[0], infer_shape_func}});
scheduled_func_bodies.push_back(cond2body.second);
}
std::vector<ir::Tensor> group_func_arg_tensors_copy = group_func_arg_tensors;
std::vector<ir::Argument> group_func_args;
std::vector<ir::LoweredFunc> funcs = PostProcess(group,
tensor_map,
apply_op_schedule,
{scheduled_func_bodies},
&group_func_arg_tensors_copy,
&group_func_args);
CHECK_EQ(funcs.size(), cond2func_bodies.size());
BucketLoweredFuncsWrapper funcs_wrapper;
for (int i = 0; i < funcs.size(); ++i) {
funcs_wrapper.predicate2funcs.emplace_back(cond2func_bodies[i].first,
funcs[i]);
}
return cond2funcs;
funcs_wrapper.infer_shape_func = GenerateInferShapeFunc(
group, group_func_arg_tensors_copy, group_func_args);

return funcs_wrapper;
}

void OpLowererImpl::InsertNameGeneToScope(std::shared_ptr<Scope> scope) {
Expand Down Expand Up @@ -300,7 +302,7 @@ std::vector<ir::LoweredFunc> OpLowererImpl::LowerMapExpr(
return PostProcess(group,
*tensor_map,
apply_op_schedule,
ir_sch.GetModule().GetExprs()[0],
{ir_sch.GetModule().GetExprs()[0]},
group_func_arg_tensors,
&group_func_args);
}
Expand Down Expand Up @@ -355,7 +357,7 @@ std::vector<ir::LoweredFunc> OpLowererImpl::LowerGroup(
return PostProcess(group,
tensor_map,
do_op_schedule,
ir_sch.GetModule().GetExprs().at(0),
{ir_sch.GetModule().GetExprs().at(0)},
&group_func_arg_tensors,
&group_func_args);
}
Expand Down Expand Up @@ -410,7 +412,7 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
const GroupPtr& group,
const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map,
bool done_op_schedule,
ir::Expr func_body,
std::vector<ir::Expr> func_bodies,
std::vector<ir::Tensor>* group_func_arg_tensors,
std::vector<ir::Argument>* group_func_args) {
// 1.Prepare function args
Expand Down Expand Up @@ -501,23 +503,28 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
}
}

std::vector<ir::LoweredFunc> lowered_funcs;
for (ir::Expr func_body : func_bodies) {
#ifdef CINN_WITH_CUDA
optim::OptimizeExprGPU(&(func_body));
optim::OptimizeExprGPU(&(func_body));
#endif

// 2.Prepare temp buffers
poly::StageMap stages;
auto temp_buffers =
lang::GetTempBuffers(*group_func_arg_tensors, stages, func_body);
// 3.Building LoweredFunc
auto func = ir::_LoweredFunc_::Make(
group->FuncName(), *group_func_args, func_body, temp_buffers);
if (!done_op_schedule) {
func->PrepareBufferCastExprs();
// 2.Prepare temp buffers
poly::StageMap stages;
auto temp_buffers =
lang::GetTempBuffers(*group_func_arg_tensors, stages, func_body);
// 3.Building LoweredFunc
auto func = ir::_LoweredFunc_::Make(
group->FuncName(), *group_func_args, func_body, temp_buffers);
if (!done_op_schedule) {
func->PrepareBufferCastExprs();
}
// 4.Apply low level pass
func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref();
lowered_funcs.push_back(std::move(func));
}
// 4.Apply low level pass
func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref();
return {func};

return lowered_funcs;
}

std::vector<ir::Expr> OpLowererImpl::LowerOps(
Expand Down
13 changes: 6 additions & 7 deletions paddle/cinn/hlir/framework/pir/op_lowering_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,10 @@ class OpLowererImpl : public OpLowererImplBase<GroupPtr> {
* @param apply_group_schedule Whether to schedule at group level.
* @return The lowered funcs.
*/
std::vector<std::pair<ir::SymbolicPredicate, OpLowererImpl::WrapLoweredFunc>>
BucketLower(const GroupPtr& group,
bool apply_op_schedule = false,
bool apply_group_schedule = true,
bool apply_pass = true);
BucketLoweredFuncsWrapper BucketLower(const GroupPtr& group,
bool apply_op_schedule = false,
bool apply_group_schedule = true,
bool apply_pass = true);

void InsertNameGeneToScope(std::shared_ptr<Scope> scope);

Expand Down Expand Up @@ -108,7 +107,7 @@ class OpLowererImpl : public OpLowererImplBase<GroupPtr> {
* @param tensor_map All tensors used for calculating the group.
* @param done_op_schedule Mark whether the Op level schedule has been
* applied.
* @param func_body The scheduled func body of group.
* @param func_bodies The scheduled func bodies of group.
* @param group_func_arg_tensors Tensors used as the group function arguments.
* @param group_func_args Arguments used as the group function arguments.
* @return The lowered funcs after the post processing.
Expand All @@ -117,7 +116,7 @@ class OpLowererImpl : public OpLowererImplBase<GroupPtr> {
const GroupPtr& group,
const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map,
bool done_op_schedule,
ir::Expr func_body,
std::vector<ir::Expr> func_bodies,
std::vector<ir::Tensor>* group_func_arg_tensors,
std::vector<ir::Argument>* group_func_args);

Expand Down
Loading

0 comments on commit e2b4247

Please sign in to comment.