From e2b42473b2173f77c6ad05ee12b91e6ac771641b Mon Sep 17 00:00:00 2001 From: BiynXu <62832681+BiynXu@users.noreply.github.com> Date: Mon, 8 Jan 2024 11:11:40 +0800 Subject: [PATCH] [CINN]Add bucket context (#60549) * [CINN] Add tile tactic * [CINN] Add bind cuda tactic * [CINN] Add bucket contexts * fix group output args bug --- paddle/cinn/hlir/framework/op_lowering.h | 10 +- paddle/cinn/hlir/framework/op_lowering_impl.h | 9 +- .../hlir/framework/op_lowering_impl_base.h | 24 ++-- .../hlir/framework/pir/compilation_task.cc | 16 ++- .../hlir/framework/pir/compilation_task.h | 6 +- .../hlir/framework/pir/op_lowering_impl.cc | 85 +++++++------ .../hlir/framework/pir/op_lowering_impl.h | 13 +- .../dy_shape_group_scheduler.cc | 118 ++++++++++++------ .../group_schedule/dy_shape_group_scheduler.h | 18 ++- .../group_schedule/tactic/schedule_tactic.h | 3 +- .../ir/group_schedule/tactic/tile_tactic.cc | 33 +++++ paddle/cinn/ir/module.cc | 2 +- paddle/cinn/ir/module.h | 2 +- .../instruction/cinn_jit_instruction.cc | 2 - 14 files changed, 213 insertions(+), 128 deletions(-) diff --git a/paddle/cinn/hlir/framework/op_lowering.h b/paddle/cinn/hlir/framework/op_lowering.h index d4b4a78e9cd3f..f1f1554870663 100644 --- a/paddle/cinn/hlir/framework/op_lowering.h +++ b/paddle/cinn/hlir/framework/op_lowering.h @@ -47,12 +47,10 @@ class OpLowerer { group, apply_op_schedule, apply_group_schedule, apply_pass); } - std::vector< - std::pair> - BucketLower(const T& group, - bool apply_op_schedule = false, - bool apply_group_schedule = true, - bool apply_pass = true) { + BucketLoweredFuncsWrapper BucketLower(const T& group, + bool apply_op_schedule = false, + bool apply_group_schedule = true, + bool apply_pass = true) { return impl_->BucketLower( group, apply_op_schedule, apply_group_schedule, apply_pass); } diff --git a/paddle/cinn/hlir/framework/op_lowering_impl.h b/paddle/cinn/hlir/framework/op_lowering_impl.h index d48cbbeb7e9b4..5e57c607c93e1 100644 --- a/paddle/cinn/hlir/framework/op_lowering_impl.h +++ b/paddle/cinn/hlir/framework/op_lowering_impl.h @@ -60,11 +60,10 @@ class OpLowererImpl : public OpLowererImplBase { bool apply_group_schedule = true, bool apply_pass = true); - std::vector> BucketLower( - const GroupPtr& group, - bool apply_op_schedule = false, - bool apply_group_schedule = true, - bool apply_pass = true) { + BucketLoweredFuncsWrapper BucketLower(const GroupPtr& group, + bool apply_op_schedule = false, + bool apply_group_schedule = true, + bool apply_pass = true) { CINN_NOT_IMPLEMENTED; } diff --git a/paddle/cinn/hlir/framework/op_lowering_impl_base.h b/paddle/cinn/hlir/framework/op_lowering_impl_base.h index 32bda3ca50f67..b67deedbbb7c5 100644 --- a/paddle/cinn/hlir/framework/op_lowering_impl_base.h +++ b/paddle/cinn/hlir/framework/op_lowering_impl_base.h @@ -27,16 +27,15 @@ namespace cinn { namespace hlir { namespace framework { +struct BucketLoweredFuncsWrapper { + std::vector> + predicate2funcs; + ir::LoweredFunc infer_shape_func; +}; + template class OpLowererImplBase { public: - struct WrapLoweredFunc { - ir::LoweredFunc kernel_func; - ir::LoweredFunc infer_shape_func; - WrapLoweredFunc(ir::LoweredFunc kernel_func, - ir::LoweredFunc infer_shape_func = ir::LoweredFunc()) - : infer_shape_func(infer_shape_func), kernel_func(kernel_func) {} - }; OpLowererImplBase() = default; ~OpLowererImplBase() = default; @@ -45,11 +44,12 @@ class OpLowererImplBase { bool apply_group_schedule = true, bool apply_pass = true) = 0; - virtual std::vector> - BucketLower(const T& group, - bool apply_op_schedule = false, - bool apply_group_schedule = true, - bool apply_pass = true) = 0; + virtual BucketLoweredFuncsWrapper BucketLower( + const T& group, + bool apply_op_schedule = false, + bool apply_group_schedule = true, + bool apply_pass = true) = 0; + virtual void InsertNameGeneToScope(std::shared_ptr scope) = 0; }; diff --git a/paddle/cinn/hlir/framework/pir/compilation_task.cc b/paddle/cinn/hlir/framework/pir/compilation_task.cc index c6d3412102c30..6be8e61600585 100644 --- a/paddle/cinn/hlir/framework/pir/compilation_task.cc +++ b/paddle/cinn/hlir/framework/pir/compilation_task.cc @@ -24,16 +24,14 @@ namespace hlir { namespace framework { void GroupCompilationContext::SetLoweredFuncs( - std::vector>&& funcs) { - for (std::pair& - predicate2func : funcs) { - predicates_.push_back(predicate2func.first); - lowered_funcs_.push_back(predicate2func.second.kernel_func); - infer_shape_lowered_funcs_.push_back( - predicate2func.second.infer_shape_func); + BucketLoweredFuncsWrapper&& funcs) { + for (std::pair& predicate2func : + funcs.predicate2funcs) { + predicates_.push_back(std::move(predicate2func.first)); + lowered_funcs_.push_back(std::move(predicate2func.second)); ++func_size_; } + infer_shape_lowered_func_ = std::move(funcs.infer_shape_func); } std::string GroupCompilationContext::PrintPredicate2Funcs() const { @@ -77,7 +75,7 @@ void CompilationTask::CodegenAndJit() { for (const ir::LoweredFunc& func : context_->lowered_funcs_) { builder.AddFunction(func); } - builder.AddInferShapeFunc(context_->infer_shape_lowered_funcs_[0]); + builder.SetInferShapeFunc(context_->infer_shape_lowered_func_); ir::Module ir_module = builder.Build(); context_->backend_compiler_ = backends::Compiler::Create(context_->target_); diff --git a/paddle/cinn/hlir/framework/pir/compilation_task.h b/paddle/cinn/hlir/framework/pir/compilation_task.h index 9e96c64694527..e76f93d206096 100644 --- a/paddle/cinn/hlir/framework/pir/compilation_task.h +++ b/paddle/cinn/hlir/framework/pir/compilation_task.h @@ -31,9 +31,7 @@ class GroupCompilationContext { std::shared_ptr scope) : target_(target), group_(group), scope_(scope) {} - void SetLoweredFuncs( - std::vector>&& funcs); + void SetLoweredFuncs(BucketLoweredFuncsWrapper&& funcs); std::string PrintPredicate2Funcs() const; void* FuncPtr(); std::shared_ptr BackendCompiler(); @@ -48,7 +46,7 @@ class GroupCompilationContext { size_t func_size_ = 0; std::vector predicates_; std::vector lowered_funcs_; - std::vector infer_shape_lowered_funcs_; + ir::LoweredFunc infer_shape_lowered_func_; std::string host_func_name_; std::string host_code_; std::vector device_code_; diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc index 062e5db1cc1f8..1802a1404da0a 100644 --- a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc +++ b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc @@ -99,17 +99,14 @@ std::vector OpLowererImpl::Lower(const GroupPtr& group, LOG(FATAL) << "Group Pattern Kind Is Unknown!"; } } - -std::vector> -OpLowererImpl::BucketLower(const GroupPtr& group, - bool apply_op_schedule, - bool apply_group_schedule, - bool apply_pass) { +BucketLoweredFuncsWrapper OpLowererImpl::BucketLower(const GroupPtr& group, + bool apply_op_schedule, + bool apply_group_schedule, + bool apply_pass) { // 1.Do compute, lower and schedule for each op. auto& ops = group->ops; if (ops.size() == 1 && ops[0]->name() == "custom_call") { - return {{ir::Expr(1), - pir::OpLowererImpl::WrapLoweredFunc(LowerCustomCall(group)[0])}}; + return {{{ir::Expr(1), LowerCustomCall(group)[0]}}, ir::LoweredFunc()}; } std::vector group_func_arg_tensors; std::unordered_map<::pir::Value, ir::Tensor> tensor_map; @@ -152,24 +149,29 @@ OpLowererImpl::BucketLower(const GroupPtr& group, // 3.Do post-processing, // including preparing function args and temporary variables, // applying low-level optimization passes, etc. - std::vector> cond2funcs; + std::vector scheduled_func_bodies; for (std::pair& cond2body : cond2func_bodies) { - std::vector group_func_arg_tensors_copy = - group_func_arg_tensors; - std::vector group_func_args; - std::vector funcs = - PostProcess(group, - tensor_map, - apply_op_schedule, - cond2body.second, - &group_func_arg_tensors_copy, - &group_func_args); - ir::LoweredFunc infer_shape_func = GenerateInferShapeFunc( - group, group_func_arg_tensors_copy, group_func_args); - cond2funcs.push_back({cond2body.first, {funcs[0], infer_shape_func}}); + scheduled_func_bodies.push_back(cond2body.second); + } + std::vector group_func_arg_tensors_copy = group_func_arg_tensors; + std::vector group_func_args; + std::vector funcs = PostProcess(group, + tensor_map, + apply_op_schedule, + {scheduled_func_bodies}, + &group_func_arg_tensors_copy, + &group_func_args); + CHECK_EQ(funcs.size(), cond2func_bodies.size()); + BucketLoweredFuncsWrapper funcs_wrapper; + for (int i = 0; i < funcs.size(); ++i) { + funcs_wrapper.predicate2funcs.emplace_back(cond2func_bodies[i].first, + funcs[i]); } - return cond2funcs; + funcs_wrapper.infer_shape_func = GenerateInferShapeFunc( + group, group_func_arg_tensors_copy, group_func_args); + + return funcs_wrapper; } void OpLowererImpl::InsertNameGeneToScope(std::shared_ptr scope) { @@ -300,7 +302,7 @@ std::vector OpLowererImpl::LowerMapExpr( return PostProcess(group, *tensor_map, apply_op_schedule, - ir_sch.GetModule().GetExprs()[0], + {ir_sch.GetModule().GetExprs()[0]}, group_func_arg_tensors, &group_func_args); } @@ -355,7 +357,7 @@ std::vector OpLowererImpl::LowerGroup( return PostProcess(group, tensor_map, do_op_schedule, - ir_sch.GetModule().GetExprs().at(0), + {ir_sch.GetModule().GetExprs().at(0)}, &group_func_arg_tensors, &group_func_args); } @@ -410,7 +412,7 @@ std::vector OpLowererImpl::PostProcess( const GroupPtr& group, const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map, bool done_op_schedule, - ir::Expr func_body, + std::vector func_bodies, std::vector* group_func_arg_tensors, std::vector* group_func_args) { // 1.Prepare function args @@ -501,23 +503,28 @@ std::vector OpLowererImpl::PostProcess( } } + std::vector lowered_funcs; + for (ir::Expr func_body : func_bodies) { #ifdef CINN_WITH_CUDA - optim::OptimizeExprGPU(&(func_body)); + optim::OptimizeExprGPU(&(func_body)); #endif - // 2.Prepare temp buffers - poly::StageMap stages; - auto temp_buffers = - lang::GetTempBuffers(*group_func_arg_tensors, stages, func_body); - // 3.Building LoweredFunc - auto func = ir::_LoweredFunc_::Make( - group->FuncName(), *group_func_args, func_body, temp_buffers); - if (!done_op_schedule) { - func->PrepareBufferCastExprs(); + // 2.Prepare temp buffers + poly::StageMap stages; + auto temp_buffers = + lang::GetTempBuffers(*group_func_arg_tensors, stages, func_body); + // 3.Building LoweredFunc + auto func = ir::_LoweredFunc_::Make( + group->FuncName(), *group_func_args, func_body, temp_buffers); + if (!done_op_schedule) { + func->PrepareBufferCastExprs(); + } + // 4.Apply low level pass + func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref(); + lowered_funcs.push_back(std::move(func)); } - // 4.Apply low level pass - func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref(); - return {func}; + + return lowered_funcs; } std::vector OpLowererImpl::LowerOps( diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_impl.h b/paddle/cinn/hlir/framework/pir/op_lowering_impl.h index 0a9f4d4b33820..f1ab9730a2df9 100644 --- a/paddle/cinn/hlir/framework/pir/op_lowering_impl.h +++ b/paddle/cinn/hlir/framework/pir/op_lowering_impl.h @@ -70,11 +70,10 @@ class OpLowererImpl : public OpLowererImplBase { * @param apply_group_schedule Whether to schedule at group level. * @return The lowered funcs. */ - std::vector> - BucketLower(const GroupPtr& group, - bool apply_op_schedule = false, - bool apply_group_schedule = true, - bool apply_pass = true); + BucketLoweredFuncsWrapper BucketLower(const GroupPtr& group, + bool apply_op_schedule = false, + bool apply_group_schedule = true, + bool apply_pass = true); void InsertNameGeneToScope(std::shared_ptr scope); @@ -108,7 +107,7 @@ class OpLowererImpl : public OpLowererImplBase { * @param tensor_map All tensors used for calculating the group. * @param done_op_schedule Mark whether the Op level schedule has been * applied. - * @param func_body The scheduled func body of group. + * @param func_bodies The scheduled func bodies of group. * @param group_func_arg_tensors Tensors used as the group function arguments. * @param group_func_args Arguments used as the group function arguments. * @return The lowered funcs after the post processing. @@ -117,7 +116,7 @@ class OpLowererImpl : public OpLowererImplBase { const GroupPtr& group, const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map, bool done_op_schedule, - ir::Expr func_body, + std::vector func_bodies, std::vector* group_func_arg_tensors, std::vector* group_func_args); diff --git a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc index 9f7a52d97fb17..657742e37ab42 100644 --- a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc +++ b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc @@ -25,16 +25,7 @@ namespace cinn { namespace ir { void DynamicShapeGroupScheduler::Init() { - // Only 1 bucket for test now. - schedule_context_.target = target_; - schedule_context_.output_names = OutputTensorNames(); - schedule_context_.global_master = FindGlobalMasterNode(); - schedule_context_.iter_space_info = - ConstructIterSpaceInfo(schedule_context_.global_master); - schedule_context_.bucket_info = {/* sp_lower_bound = */ 1024, - /* sp_upper_bound = */ INT_MAX, - /* rb_lower_bound = */ 64, - /* rb_upper_bound = */ INT_MAX}; + InitBuckets(); tactics_.emplace_back(new AlignIterSpaceTactic()); tactics_.emplace_back(new TileTactic()); tactics_.emplace_back(new ComputeInlineTactic()); @@ -42,43 +33,99 @@ void DynamicShapeGroupScheduler::Init() { tactics_.emplace_back(new ArrangeStorageTactic()); } +void DynamicShapeGroupScheduler::InitBuckets() { + std::unordered_set output_names = OutputTensorNames(); + ir::Expr fake_predicate = ir::LE::Make(Expr(1023), Expr(1024)); + auto InitBucket = [&](BucketInfo&& bucket_info) { + std::unique_ptr ir_sch = + std::make_unique(*ir_sch_); + std::unique_ptr schedule_block_graph = + std::make_unique(*ir_sch); + ir::ScheduleBlockNode* global_master = + FindGlobalMasterNode(schedule_block_graph); + IterativeSpaceInfo iter_space_info = ConstructIterSpaceInfo(global_master); + SymbolicPredicate sp_lower_bound_predicate = ir::GE::Make( + iter_space_info.total_sp_extent, ir::Expr(bucket_info.sp_lower_bound)); + SymbolicPredicate sp_upper_bound_predicate = ir::LT::Make( + iter_space_info.total_sp_extent, ir::Expr(bucket_info.sp_upper_bound)); + SymbolicPredicate rb_lower_bound_predicate = ir::GE::Make( + iter_space_info.total_rb_extent, ir::Expr(bucket_info.rb_lower_bound)); + SymbolicPredicate rb_upper_bound_predicate = ir::LT::Make( + iter_space_info.total_rb_extent, ir::Expr(bucket_info.rb_upper_bound)); + SymbolicPredicate sp_predicate = + ir::And::Make(sp_lower_bound_predicate, sp_upper_bound_predicate); + SymbolicPredicate rb_predicate = + ir::And::Make(rb_lower_bound_predicate, rb_upper_bound_predicate); + SymbolicPredicate predicate = ir::And::Make(sp_predicate, rb_predicate); + ScheduleContext schedule_context{output_names, + target_, + std::move(iter_space_info), + std::move(bucket_info)}; + BucketContext bucket_context{std::move(predicate), + std::move(ir_sch), + std::move(schedule_block_graph), + std::move(schedule_context)}; + bucket_contexts_.emplace_back(std::move(bucket_context)); + }; + // naive buckets + // 1. {sp_extent[1 - 1024], rb_extent[1 - 256]} + InitBucket({/* sp_lower_bound = */ 1, + /* sp_upper_bound = */ 1024, + /* rb_lower_bound = */ 1, + /* rb_upper_bound = */ 256}); + // 2. {sp_extent[1024 - +oo], rb_extent[1 - 256]} + InitBucket({/* sp_lower_bound = */ 1024, + /* sp_upper_bound = */ INT_MAX, + /* rb_lower_bound = */ 1, + /* rb_upper_bound = */ 256}); + // 3. {sp_extent[1 - 1024], rb_extent[256 - +oo]} + InitBucket({/* sp_lower_bound = */ 1, + /* sp_upper_bound = */ 1024, + /* rb_lower_bound = */ 256, + /* rb_upper_bound = */ INT_MAX}); + // 4. {sp_extent[1024 - +oo], rb_extent[256 - +oo]} + InitBucket({/* sp_lower_bound = */ 1024, + /* sp_upper_bound = */ INT_MAX, + /* rb_lower_bound = */ 256, + /* rb_upper_bound = */ INT_MAX}); +} + void DynamicShapeGroupScheduler::Schedule() { - ApplyTactics(); - // Fake bucket for test - ir::Expr predicate1 = ir::LE::Make(Expr(1023), Expr(1024)); - std::unique_ptr new_ir_sch1 = - std::make_unique(*ir_sch_); - ir_schs_.emplace_back(predicate1, std::move(new_ir_sch1)); + for (BucketContext& bucket_context : bucket_contexts_) { + VLOG(4) << "===========================Apply tactics on Bucket [" + << bucket_context.predicate << "]=========================="; + ApplyTactics(&bucket_context); + } } -void DynamicShapeGroupScheduler::ApplyTactics() { - schedule_block_graph_->Update(*ir_sch_); +void DynamicShapeGroupScheduler::ApplyTactics(BucketContext* bucket_context) { + bucket_context->schedule_block_graph->Update(*(bucket_context->ir_sch)); for (const auto& tactic : tactics_) { VLOG(5) << "[Start " << tactic->TacticName() << "] func body:\n" - << ir_sch_->GetModule().GetExprs().front(); + << bucket_context->ir_sch->GetModule().GetExprs().front(); auto ApplyTacticFunc = [&](ir::ScheduleBlockNode* node) { VLOG(6) << "before applying [" << tactic->TacticName() << "] on ScheduleBlockNode [" << node->id() << "] func body:\n" - << ir_sch_->GetModule().GetExprs().front(); - tactic->Apply(ir_sch_, node->id()); + << bucket_context->ir_sch->GetModule().GetExprs().front(); + tactic->Apply(bucket_context->ir_sch.get(), node->id()); VLOG(6) << "after applying [" << tactic->TacticName() << "] on ScheduleBlockNode [" << node->id() << "] func body:\n" - << ir_sch_->GetModule().GetExprs().front(); + << bucket_context->ir_sch->GetModule().GetExprs().front(); }; - tactic->Init(&schedule_context_); - schedule_block_graph_->DFSTopoWalk(ApplyTacticFunc); - schedule_block_graph_->Update(*ir_sch_); - VLOG(5) << "[End " << tactic->TacticName() - << "] func body: " << ir_sch_->GetModule().GetExprs().front(); + tactic->Init(&(bucket_context->schedule_context)); + bucket_context->schedule_block_graph->DFSTopoWalk(ApplyTacticFunc); + bucket_context->schedule_block_graph->Update(*(bucket_context->ir_sch)); + VLOG(5) << "[End " << tactic->TacticName() << "] func body: " + << bucket_context->ir_sch->GetModule().GetExprs().front(); } } std::vector> DynamicShapeGroupScheduler::GetIRs() { std::vector> irs; - for (auto& sch_pair : ir_schs_) { - irs.emplace_back(sch_pair.first, - sch_pair.second->GetModule().GetExprs()[0]); + for (BucketContext& context : bucket_contexts_) { + irs.emplace_back(context.predicate, + context.ir_sch->GetModule().GetExprs()[0]); } return irs; } @@ -95,7 +142,7 @@ IterativeSpaceInfo DynamicShapeGroupScheduler::ConstructIterSpaceInfo( std::vector iter_vars = block.As() ->schedule_block.As() ->iter_vars; - std::vector loops = ir_sch_->GetLoops(block); + std::vector loops = node->GetLoops(); std::unordered_set reduce_iter_vars = analyzer::GetReduceIterVars(block); std::unordered_map iter_var2value = @@ -184,7 +231,8 @@ IterativeSpaceInfo DynamicShapeGroupScheduler::ConstructIterSpaceInfo( return info; } -ir::ScheduleBlockNode* DynamicShapeGroupScheduler::FindGlobalMasterNode() { +ir::ScheduleBlockNode* DynamicShapeGroupScheduler::FindGlobalMasterNode( + const std::unique_ptr& schedule_block_graph) { ir::ScheduleBlockNode* master = nullptr; // 1. reduce auto FindReduce = [&](ir::ScheduleBlockNode* node) { @@ -192,7 +240,7 @@ ir::ScheduleBlockNode* DynamicShapeGroupScheduler::FindGlobalMasterNode() { master = node; } }; - schedule_block_graph_->NodesWalk(FindReduce); + schedule_block_graph->NodesWalk(FindReduce); if (master != nullptr) { VLOG(6) << "Find the global master node: " << master->id(); return master; @@ -203,13 +251,13 @@ ir::ScheduleBlockNode* DynamicShapeGroupScheduler::FindGlobalMasterNode() { master = node; } }; - schedule_block_graph_->NodesWalk(FindBroadcast); + schedule_block_graph->NodesWalk(FindBroadcast); if (master != nullptr) { VLOG(6) << "Find the global master node: " << master->id(); return master; } // 3. end point - master = schedule_block_graph_->EndPoints().back(); + master = schedule_block_graph->EndPoints().back(); VLOG(6) << "Find the global master node: " << master->id(); return master; } diff --git a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h index 896fe86bec852..e226059011b63 100644 --- a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h +++ b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h @@ -37,20 +37,28 @@ class DynamicShapeGroupScheduler : public GroupScheduler { std::vector> GetIRs() override; + struct BucketContext { + SymbolicPredicate predicate; + std::unique_ptr ir_sch; + std::unique_ptr schedule_block_graph; + ScheduleContext schedule_context; + }; + private: void Init(); - void ApplyTactics(); + void InitBuckets(); + + void ApplyTactics(BucketContext* bucket_context); - ir::ScheduleBlockNode* FindGlobalMasterNode(); + ir::ScheduleBlockNode* FindGlobalMasterNode( + const std::unique_ptr& schedule_block_graph); IterativeSpaceInfo ConstructIterSpaceInfo(ScheduleBlockNode* node); private: - std::vector>> - ir_schs_; + std::vector bucket_contexts_; std::vector> tactics_; - ScheduleContext schedule_context_; }; } // namespace ir diff --git a/paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h b/paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h index 05c258b82c47c..87c387c65d817 100644 --- a/paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h +++ b/paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h @@ -58,9 +58,8 @@ struct BucketInfo { struct ScheduleContext { std::unordered_set output_names; - ScheduleBlockNode* global_master; - IterativeSpaceInfo iter_space_info; Target target; + IterativeSpaceInfo iter_space_info; BucketInfo bucket_info; }; diff --git a/paddle/cinn/ir/group_schedule/tactic/tile_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/tile_tactic.cc index 3cace2636f2d3..9586568f51f73 100644 --- a/paddle/cinn/ir/group_schedule/tactic/tile_tactic.cc +++ b/paddle/cinn/ir/group_schedule/tactic/tile_tactic.cc @@ -22,6 +22,7 @@ void TileTactic::Init(ScheduleContext* context) { context_ = context; // fake strategy auto GetFirstFactor = [](int num) { + if (num == 1) return 1; int factor = 1; for (int i = num - 1; i >= 1; --i) { if (num % i == 0) { @@ -32,6 +33,8 @@ void TileTactic::Init(ScheduleContext* context) { bool has_rb_iter = !context_->iter_space_info.rb_space.empty(); bool has_sp_iter = !context_->iter_space_info.sp_space.empty(); + VLOG(6) << "has_sp_iter = " << has_sp_iter + << ", has_rb_iter = " << has_rb_iter; context_->iter_space_info.rb_space.clear(); context_->iter_space_info.sp_space.clear(); @@ -40,20 +43,50 @@ void TileTactic::Init(ScheduleContext* context) { context_->iter_space_info.sp_space.emplace_back( ir::Expr(context_->bucket_info.sp_lower_bound / sp_factor), IterativeSpaceInfo::AxisType::kCudaBlockX); + VLOG(6) << "sp_space: <" + << std::get<0>(context_->iter_space_info.sp_space.back()) + << ", AxisType[" + << static_cast( + std::get<1>(context_->iter_space_info.sp_space.back())) + << "]>"; context_->iter_space_info.sp_space.emplace_back( ir::Expr(sp_factor), has_rb_iter ? IterativeSpaceInfo::AxisType::kCudaThreadY : IterativeSpaceInfo::AxisType::kCudaThreadX); + VLOG(6) << "sp_space: <" + << std::get<0>(context_->iter_space_info.sp_space.back()) + << ", AxisType[" + << static_cast( + std::get<1>(context_->iter_space_info.sp_space.back())) + << "]>"; context_->iter_space_info.sp_space.emplace_back( ir::Expr(-1), IterativeSpaceInfo::AxisType::kSerial); + VLOG(6) << "sp_space: <" + << std::get<0>(context_->iter_space_info.sp_space.back()) + << ", AxisType[" + << static_cast( + std::get<1>(context_->iter_space_info.sp_space.back())) + << "]>"; } if (has_rb_iter) { context_->iter_space_info.rb_space.emplace_back( ir::Expr(context_->bucket_info.rb_lower_bound), IterativeSpaceInfo::AxisType::kCudaThreadX); + VLOG(6) << "rb_space: <" + << std::get<0>(context_->iter_space_info.rb_space.back()) + << ", AxisType[" + << static_cast( + std::get<1>(context_->iter_space_info.rb_space.back())) + << "]>"; context_->iter_space_info.rb_space.emplace_back( ir::Expr(-1), IterativeSpaceInfo::AxisType::kSerial); + VLOG(6) << "rb_space: <" + << std::get<0>(context_->iter_space_info.rb_space.back()) + << ", AxisType[" + << static_cast( + std::get<1>(context_->iter_space_info.rb_space.back())) + << "]>"; } } diff --git a/paddle/cinn/ir/module.cc b/paddle/cinn/ir/module.cc index fc58e44956fe7..20298e32920fb 100644 --- a/paddle/cinn/ir/module.cc +++ b/paddle/cinn/ir/module.cc @@ -53,7 +53,7 @@ void Module::Builder::AddPredicate(ir::Expr predicate) { module_->predicates.push_back(predicate); } -void Module::Builder::AddInferShapeFunc(ir::Expr infer_shape_func) { +void Module::Builder::SetInferShapeFunc(ir::Expr infer_shape_func) { module_->infer_shape_func = infer_shape_func; } diff --git a/paddle/cinn/ir/module.h b/paddle/cinn/ir/module.h index 9910caab42b50..160d0087a0e54 100644 --- a/paddle/cinn/ir/module.h +++ b/paddle/cinn/ir/module.h @@ -45,7 +45,7 @@ class Module : public ir::IrNodeRef { void AddFunctionWithoutOptim(const ir::LoweredFunc& func); void AddBuffer(ir::Buffer buffer); void AddPredicate(ir::Expr predicate); - void AddInferShapeFunc(ir::Expr infer_shape_func); + void SetInferShapeFunc(ir::Expr infer_shape_func); void Clear(); Target::Arch GetTargetArch(); diff --git a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc index 180eb4f478fa6..d8fd3db290b33 100644 --- a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc @@ -50,8 +50,6 @@ class CinnJitInstruction::FnPtrImpl { } // 2. Convert arg's data about shape of Tensor to cinn_pod_value_t for (const auto& int_arg_mp : cinn_kernel_info_.int_args_map) { - func_args_.emplace_back(kernel_args[int_arg_mp.second.arg_idx]->dims().at( - int_arg_mp.second.dim_idx)); func_args_.emplace_back(static_cast( kernel_args[int_arg_mp.second.arg_idx]->dims().at( int_arg_mp.second.dim_idx)));