Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN]Add bucket context #60549

Merged
merged 7 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions paddle/cinn/hlir/framework/op_lowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,10 @@ class OpLowerer {
group, apply_op_schedule, apply_group_schedule, apply_pass);
}

std::vector<
std::pair<ir::SymbolicPredicate, pir::OpLowererImpl::WrapLoweredFunc>>
BucketLower(const T& group,
bool apply_op_schedule = false,
bool apply_group_schedule = true,
bool apply_pass = true) {
BucketLoweredFuncsWrapper BucketLower(const T& group,
bool apply_op_schedule = false,
bool apply_group_schedule = true,
bool apply_pass = true) {
return impl_->BucketLower(
group, apply_op_schedule, apply_group_schedule, apply_pass);
}
Expand Down
9 changes: 4 additions & 5 deletions paddle/cinn/hlir/framework/op_lowering_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,10 @@ class OpLowererImpl : public OpLowererImplBase<GroupPtr> {
bool apply_group_schedule = true,
bool apply_pass = true);

std::vector<std::pair<ir::SymbolicPredicate, WrapLoweredFunc>> BucketLower(
const GroupPtr& group,
bool apply_op_schedule = false,
bool apply_group_schedule = true,
bool apply_pass = true) {
BucketLoweredFuncsWrapper BucketLower(const GroupPtr& group,
bool apply_op_schedule = false,
bool apply_group_schedule = true,
bool apply_pass = true) {
CINN_NOT_IMPLEMENTED;
}

Expand Down
24 changes: 12 additions & 12 deletions paddle/cinn/hlir/framework/op_lowering_impl_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,15 @@ namespace cinn {
namespace hlir {
namespace framework {

struct BucketLoweredFuncsWrapper {
std::vector<std::pair<ir::SymbolicPredicate, ir::LoweredFunc>>
predicate2funcs;
ir::LoweredFunc infer_shape_func;
};

template <typename T>
class OpLowererImplBase {
public:
struct WrapLoweredFunc {
ir::LoweredFunc kernel_func;
ir::LoweredFunc infer_shape_func;
WrapLoweredFunc(ir::LoweredFunc kernel_func,
ir::LoweredFunc infer_shape_func = ir::LoweredFunc())
: infer_shape_func(infer_shape_func), kernel_func(kernel_func) {}
};
OpLowererImplBase() = default;
~OpLowererImplBase() = default;

Expand All @@ -45,11 +44,12 @@ class OpLowererImplBase {
bool apply_group_schedule = true,
bool apply_pass = true) = 0;

virtual std::vector<std::pair<ir::SymbolicPredicate, WrapLoweredFunc>>
BucketLower(const T& group,
bool apply_op_schedule = false,
bool apply_group_schedule = true,
bool apply_pass = true) = 0;
virtual BucketLoweredFuncsWrapper BucketLower(
const T& group,
bool apply_op_schedule = false,
bool apply_group_schedule = true,
bool apply_pass = true) = 0;

virtual void InsertNameGeneToScope(std::shared_ptr<Scope> scope) = 0;
};

Expand Down
16 changes: 7 additions & 9 deletions paddle/cinn/hlir/framework/pir/compilation_task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,14 @@ namespace hlir {
namespace framework {

void GroupCompilationContext::SetLoweredFuncs(
std::vector<std::pair<ir::SymbolicPredicate,
pir::OpLowererImpl::WrapLoweredFunc>>&& funcs) {
for (std::pair<ir::SymbolicPredicate, pir::OpLowererImpl::WrapLoweredFunc>&
predicate2func : funcs) {
predicates_.push_back(predicate2func.first);
lowered_funcs_.push_back(predicate2func.second.kernel_func);
infer_shape_lowered_funcs_.push_back(
predicate2func.second.infer_shape_func);
BucketLoweredFuncsWrapper&& funcs) {
for (std::pair<ir::SymbolicPredicate, ir::LoweredFunc>& predicate2func :
funcs.predicate2funcs) {
predicates_.push_back(std::move(predicate2func.first));
lowered_funcs_.push_back(std::move(predicate2func.second));
++func_size_;
}
infer_shape_lowered_func_ = std::move(funcs.infer_shape_func);
}

std::string GroupCompilationContext::PrintPredicate2Funcs() const {
Expand Down Expand Up @@ -77,7 +75,7 @@ void CompilationTask::CodegenAndJit() {
for (const ir::LoweredFunc& func : context_->lowered_funcs_) {
builder.AddFunction(func);
}
builder.AddInferShapeFunc(context_->infer_shape_lowered_funcs_[0]);
builder.SetInferShapeFunc(context_->infer_shape_lowered_func_);
ir::Module ir_module = builder.Build();

context_->backend_compiler_ = backends::Compiler::Create(context_->target_);
Expand Down
6 changes: 2 additions & 4 deletions paddle/cinn/hlir/framework/pir/compilation_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ class GroupCompilationContext {
std::shared_ptr<Scope> scope)
: target_(target), group_(group), scope_(scope) {}

void SetLoweredFuncs(
std::vector<std::pair<ir::SymbolicPredicate,
pir::OpLowererImpl::WrapLoweredFunc>>&& funcs);
void SetLoweredFuncs(BucketLoweredFuncsWrapper&& funcs);
std::string PrintPredicate2Funcs() const;
void* FuncPtr();
std::shared_ptr<backends::Compiler> BackendCompiler();
Expand All @@ -48,7 +46,7 @@ class GroupCompilationContext {
size_t func_size_ = 0;
std::vector<ir::SymbolicPredicate> predicates_;
std::vector<ir::LoweredFunc> lowered_funcs_;
std::vector<ir::LoweredFunc> infer_shape_lowered_funcs_;
ir::LoweredFunc infer_shape_lowered_func_;
std::string host_func_name_;
std::string host_code_;
std::vector<std::string> device_code_;
Expand Down
85 changes: 46 additions & 39 deletions paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,14 @@ std::vector<ir::LoweredFunc> OpLowererImpl::Lower(const GroupPtr& group,
LOG(FATAL) << "Group Pattern Kind Is Unknown!";
}
}

std::vector<std::pair<ir::SymbolicPredicate, OpLowererImpl::WrapLoweredFunc>>
OpLowererImpl::BucketLower(const GroupPtr& group,
bool apply_op_schedule,
bool apply_group_schedule,
bool apply_pass) {
BucketLoweredFuncsWrapper OpLowererImpl::BucketLower(const GroupPtr& group,
bool apply_op_schedule,
bool apply_group_schedule,
bool apply_pass) {
// 1.Do compute, lower and schedule for each op.
auto& ops = group->ops;
if (ops.size() == 1 && ops[0]->name() == "custom_call") {
return {{ir::Expr(1),
pir::OpLowererImpl::WrapLoweredFunc(LowerCustomCall(group)[0])}};
return {{{ir::Expr(1), LowerCustomCall(group)[0]}}, ir::LoweredFunc()};
}
std::vector<ir::Tensor> group_func_arg_tensors;
std::unordered_map<::pir::Value, ir::Tensor> tensor_map;
Expand Down Expand Up @@ -152,24 +149,29 @@ OpLowererImpl::BucketLower(const GroupPtr& group,
// 3.Do post-processing,
// including preparing function args and temporary variables,
// applying low-level optimization passes, etc.
std::vector<std::pair<ir::Expr, WrapLoweredFunc>> cond2funcs;
std::vector<ir::Expr> scheduled_func_bodies;
for (std::pair<ir::SymbolicPredicate, ir::Expr>& cond2body :
cond2func_bodies) {
std::vector<ir::Tensor> group_func_arg_tensors_copy =
group_func_arg_tensors;
std::vector<ir::Argument> group_func_args;
std::vector<ir::LoweredFunc> funcs =
PostProcess(group,
tensor_map,
apply_op_schedule,
cond2body.second,
&group_func_arg_tensors_copy,
&group_func_args);
ir::LoweredFunc infer_shape_func = GenerateInferShapeFunc(
group, group_func_arg_tensors_copy, group_func_args);
cond2funcs.push_back({cond2body.first, {funcs[0], infer_shape_func}});
scheduled_func_bodies.push_back(cond2body.second);
}
std::vector<ir::Tensor> group_func_arg_tensors_copy = group_func_arg_tensors;
std::vector<ir::Argument> group_func_args;
std::vector<ir::LoweredFunc> funcs = PostProcess(group,
tensor_map,
apply_op_schedule,
{scheduled_func_bodies},
&group_func_arg_tensors_copy,
&group_func_args);
CHECK_EQ(funcs.size(), cond2func_bodies.size());
BucketLoweredFuncsWrapper funcs_wrapper;
for (int i = 0; i < funcs.size(); ++i) {
funcs_wrapper.predicate2funcs.emplace_back(cond2func_bodies[i].first,
funcs[i]);
}
return cond2funcs;
funcs_wrapper.infer_shape_func = GenerateInferShapeFunc(
group, group_func_arg_tensors_copy, group_func_args);

return funcs_wrapper;
}

void OpLowererImpl::InsertNameGeneToScope(std::shared_ptr<Scope> scope) {
Expand Down Expand Up @@ -300,7 +302,7 @@ std::vector<ir::LoweredFunc> OpLowererImpl::LowerMapExpr(
return PostProcess(group,
*tensor_map,
apply_op_schedule,
ir_sch.GetModule().GetExprs()[0],
{ir_sch.GetModule().GetExprs()[0]},
group_func_arg_tensors,
&group_func_args);
}
Expand Down Expand Up @@ -355,7 +357,7 @@ std::vector<ir::LoweredFunc> OpLowererImpl::LowerGroup(
return PostProcess(group,
tensor_map,
do_op_schedule,
ir_sch.GetModule().GetExprs().at(0),
{ir_sch.GetModule().GetExprs().at(0)},
&group_func_arg_tensors,
&group_func_args);
}
Expand Down Expand Up @@ -410,7 +412,7 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
const GroupPtr& group,
const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map,
bool done_op_schedule,
ir::Expr func_body,
std::vector<ir::Expr> func_bodies,
std::vector<ir::Tensor>* group_func_arg_tensors,
std::vector<ir::Argument>* group_func_args) {
// 1.Prepare function args
Expand Down Expand Up @@ -501,23 +503,28 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
}
}

std::vector<ir::LoweredFunc> lowered_funcs;
for (ir::Expr func_body : func_bodies) {
#ifdef CINN_WITH_CUDA
optim::OptimizeExprGPU(&(func_body));
optim::OptimizeExprGPU(&(func_body));
#endif

// 2.Prepare temp buffers
poly::StageMap stages;
auto temp_buffers =
lang::GetTempBuffers(*group_func_arg_tensors, stages, func_body);
// 3.Building LoweredFunc
auto func = ir::_LoweredFunc_::Make(
group->FuncName(), *group_func_args, func_body, temp_buffers);
if (!done_op_schedule) {
func->PrepareBufferCastExprs();
// 2.Prepare temp buffers
poly::StageMap stages;
auto temp_buffers =
lang::GetTempBuffers(*group_func_arg_tensors, stages, func_body);
// 3.Building LoweredFunc
auto func = ir::_LoweredFunc_::Make(
group->FuncName(), *group_func_args, func_body, temp_buffers);
if (!done_op_schedule) {
func->PrepareBufferCastExprs();
}
// 4.Apply low level pass
func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref();
lowered_funcs.push_back(std::move(func));
}
// 4.Apply low level pass
func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref();
return {func};

return lowered_funcs;
}

std::vector<ir::Expr> OpLowererImpl::LowerOps(
Expand Down
13 changes: 6 additions & 7 deletions paddle/cinn/hlir/framework/pir/op_lowering_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,10 @@ class OpLowererImpl : public OpLowererImplBase<GroupPtr> {
* @param apply_group_schedule Whether to schedule at group level.
* @return The lowered funcs.
*/
std::vector<std::pair<ir::SymbolicPredicate, OpLowererImpl::WrapLoweredFunc>>
BucketLower(const GroupPtr& group,
bool apply_op_schedule = false,
bool apply_group_schedule = true,
bool apply_pass = true);
BucketLoweredFuncsWrapper BucketLower(const GroupPtr& group,
bool apply_op_schedule = false,
bool apply_group_schedule = true,
bool apply_pass = true);

void InsertNameGeneToScope(std::shared_ptr<Scope> scope);

Expand Down Expand Up @@ -108,7 +107,7 @@ class OpLowererImpl : public OpLowererImplBase<GroupPtr> {
* @param tensor_map All tensors used for calculating the group.
* @param done_op_schedule Mark whether the Op level schedule has been
* applied.
* @param func_body The scheduled func body of group.
* @param func_bodies The scheduled func bodies of group.
* @param group_func_arg_tensors Tensors used as the group function arguments.
* @param group_func_args Arguments used as the group function arguments.
* @return The lowered funcs after the post processing.
Expand All @@ -117,7 +116,7 @@ class OpLowererImpl : public OpLowererImplBase<GroupPtr> {
const GroupPtr& group,
const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map,
bool done_op_schedule,
ir::Expr func_body,
std::vector<ir::Expr> func_bodies,
std::vector<ir::Tensor>* group_func_arg_tensors,
std::vector<ir::Argument>* group_func_args);

Expand Down
Loading