From 83cdf5e01cb0b0aff41974d9b47d4e552e8f247e Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Wed, 27 Sep 2023 04:40:00 +0000 Subject: [PATCH 01/12] add --- .../pir_adaptor/pir_adaptor_util.cc | 154 ++++++++++++++++++ .../pir_adaptor/pir_adaptor_util.h | 71 +++++--- .../pir/transforms/pd_op_to_kernel_pass.cc | 68 +++++--- 3 files changed, 253 insertions(+), 40 deletions(-) diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc index 25b68432f4bc7..f2aa098d40eab 100644 --- a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc +++ b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc @@ -125,6 +125,160 @@ std::string ValueExecutionInfo::GetNameById(int id) const { } return ""; } + +const std::unordered_map<::pir::Value, std::string>& +ValueExecutionInfo::GetValue2VarName() const { + return value_2_var_name_; +} + +void ValueExecutionInfo::AddValue2VarName(::pir::Value value, + const std::string& var_name) { + value_2_var_name_.emplace(value, var_name); +} + +const std::unordered_map& +ValueExecutionInfo::GetVar2VarName() const { + return var_2_var_name_; +} + +const std::map& ValueExecutionInfo::GetVarName2Id() const { + return var_name_2_id_; +} + +const std::unordered_map& ValueExecutionInfo::GetId2VarName() + const { + return id_2_var_name_; +} + +const std::vector& ValueExecutionInfo::GetVarList() const { + return var_list_; +} + +void ValueExecutionInfo::ResetVarList(int id, Variable* var) { + var_list_[id] = var; +} + +bool ValueExecutionInfo::HasValue(::pir::Value value) const { + return HasValueInternal(value); +} + +bool ValueExecutionInfo::HasLocalValue(::pir::Value value) const { + return HasValueLocally(value); +} + +std::string ValueExecutionInfo::GetVarName(::pir::Value value) const { + return GetVarNameInternal(value); +} + +std::string ValueExecutionInfo::GetVarName(const Variable* var) const { + return GetVarNameInternal(var); +} + +std::string ValueExecutionInfo::GetLocalVarName(::pir::Value value) const { + return GetVarNameLocally(value); +} + +std::string ValueExecutionInfo::GetLocalVarName(const Variable* var) const { + return GetVarNameLocally(var); +} + +int ValueExecutionInfo::GetVarId(::pir::Value value) const { + return GetVarIdInternal(value); +} + +int ValueExecutionInfo::GetVarId(const Variable* var) const { + return GetVarIdInternal(var); +} + +int ValueExecutionInfo::GetLocalVarId(::pir::Value value) const { + return GetVarIdLocally(value); +} + +int ValueExecutionInfo::GetLocalVarId(const Variable* var) const { + return GetVarIdLocally(var); +} + +bool ValueExecutionInfo::HasValueInternal(::pir::Value value) const { + if (HasValueLocally(value)) { + return true; + } + return (parent_ == nullptr) ? false : parent_->HasValueInternal(value); +} + +bool ValueExecutionInfo::HasValueLocally(::pir::Value value) const { + auto it = value_2_var_name_.find(value); + if (it != value_2_var_name_.end()) { + return true; + } + return false; +} + +std::string ValueExecutionInfo::GetVarNameInternal(::pir::Value value) const { + auto name = GetVarNameLocally(value); + if (name != "") { + return name; + } + return (parent_ == nullptr) ? 
"" : parent_->GetVarNameInternal(value); +} + +std::string ValueExecutionInfo::GetVarNameLocally(::pir::Value value) const { + auto it = value_2_var_name_.find(value); + if (it != value_2_var_name_.end()) { + return it->second; + } + return ""; +} + +std::string ValueExecutionInfo::GetVarNameInternal(const Variable* var) const { + auto name = GetVarNameLocally(var); + if (name != "") { + return name; + } + return (parent_ == nullptr) ? "" : parent_->GetVarNameInternal(var); +} + +std::string ValueExecutionInfo::GetVarNameLocally(const Variable* var) const { + auto it = var_2_var_name_.find(var); + if (it != var_2_var_name_.end()) { + return it->second; + } + return ""; +} + +int ValueExecutionInfo::GetVarIdInternal(::pir::Value value) const { + auto id = GetVarIdLocally(value); + if (id != -1) { + return id; + } + return (parent_ == nullptr) ? -1 : parent_->GetVarIdInternal(value); +} + +int ValueExecutionInfo::GetVarIdLocally(::pir::Value value) const { + auto var_name = GetVarNameLocally(value); + auto it = var_name_2_id_.find(var_name); + if (it != var_name_2_id_.end()) { + return it->second; + } + return -1; +} + +int ValueExecutionInfo::GetVarIdInternal(const Variable* var) const { + auto id = GetVarIdLocally(var); + if (id != -1) { + return id; + } + return (parent_ == nullptr) ? -1 : parent_->GetVarIdInternal(var); +} + +int ValueExecutionInfo::GetVarIdLocally(const Variable* var) const { + auto var_name = GetVarNameLocally(var); + auto it = var_name_2_id_.find(var_name); + if (it != var_name_2_id_.end()) { + return it->second; + } + return -1; +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h index 6801960a14017..82f93a417a32f 100644 --- a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h +++ b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h @@ -48,11 +48,13 @@ namespace framework { class CondInstruction; class ValueExecutionInfo { public: + friend class CondInstruction; + explicit ValueExecutionInfo(Scope* scope) : scope_(scope) {} const ValueExecutionInfo* Parent() const { return parent_; } - Scope* GetScope() { return scope_; } + Scope* GetScope() const { return scope_; } void Add(::pir::Value value, std::string var_name); @@ -62,35 +64,64 @@ class ValueExecutionInfo { std::string GetNameById(int id) const; - const std::unordered_map<::pir::Value, std::string>& GetValue2VarName() - const { - return value_2_var_name_; - } + const std::unordered_map<::pir::Value, std::string>& GetValue2VarName() const; - void AddValue2VarName(::pir::Value value, const std::string& var_name) { - value_2_var_name_.emplace(value, var_name); - } + void AddValue2VarName(::pir::Value value, const std::string& var_name); const std::unordered_map& - GetVar2VarName() const { - return var_2_var_name_; - } + GetVar2VarName() const; - const std::map& GetVarName2Id() const { - return var_name_2_id_; - } + const std::map& GetVarName2Id() const; - const std::unordered_map& GetId2VarName() const { - return id_2_var_name_; - } + const std::unordered_map& GetId2VarName() const; - const std::vector& GetVarList() const { return var_list_; } + const std::vector& GetVarList() const; - void ResetVarList(int id, Variable* var) { var_list_[id] = var; } + void ResetVarList(int id, Variable* var); - friend class CondInstruction; + /// Check a value exist in the ValueExecutionInfo or any of its ancestors. 
+ bool HasValue(::pir::Value value) const; + + /// Check a value exist in the ValueExecutionInfo. + bool HasLocalValue(::pir::Value value) const; + + std::string GetVarName(::pir::Value value) const; + + std::string GetVarName(const Variable* var) const; + + std::string GetLocalVarName(::pir::Value value) const; + + std::string GetLocalVarName(const Variable* var) const; + + int GetVarId(::pir::Value value) const; + + int GetVarId(const Variable* var) const; + + int GetLocalVarId(::pir::Value value) const; + + int GetLocalVarId(const Variable* var) const; private: + bool HasValueInternal(::pir::Value value) const; + + bool HasValueLocally(::pir::Value value) const; + + std::string GetVarNameInternal(::pir::Value value) const; + + std::string GetVarNameLocally(::pir::Value value) const; + + std::string GetVarNameInternal(const Variable* var) const; + + std::string GetVarNameLocally(const Variable* var) const; + + int GetVarIdInternal(::pir::Value value) const; + + int GetVarIdLocally(::pir::Value value) const; + + int GetVarIdInternal(const Variable* var) const; + + int GetVarIdLocally(const Variable* var) const; + std::shared_ptr NewChild(Scope* scope); ValueExecutionInfo* parent_{nullptr}; // not owned diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index 0059731809108..7542a558f21f6 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -229,9 +229,9 @@ pir::OpResult AddPlaceTransferOp(pir::OpResult in, const phi::KernelKey& kernel_key, pir::Block* block) { pir::IrContext* ctx = pir::IrContext::Instance(); - std::string op_name = paddle::dialect::PhiKernelOp::name(); - pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name); + pir::OpInfo kernel_op_info = + ctx->GetRegisteredOpInfo(paddle::dialect::PhiKernelOp::name()); if ((src_place.GetType() == phi::AllocationType::CPU) && (dst_place.GetType() == phi::AllocationType::GPU)) { @@ -244,7 +244,7 @@ pir::OpResult AddPlaceTransferOp(pir::OpResult in, {"dst_place_type", pir::Int32Attribute::get(ctx, 1)}}; pir::Operation* op = - pir::Operation::Create({in}, op_attribute, {out_type}, op_info); + pir::Operation::Create({in}, op_attribute, {out_type}, kernel_op_info); if (in.owner()->HasAttribute(kAttrIsPersisable)) { op->set_attribute(kAttrIsPersisable, @@ -266,7 +266,7 @@ pir::OpResult AddPlaceTransferOp(pir::OpResult in, {"dst_place_type", pir::Int32Attribute::get(ctx, 0)}}; pir::Operation* op = - pir::Operation::Create({in}, op_attribute, {out_type}, op_info); + pir::Operation::Create({in}, op_attribute, {out_type}, kernel_op_info); block->push_back(op); @@ -665,45 +665,73 @@ void HandleForIfOp( pir::IrContext* ctx, std::unordered_map* map_op_pair, std::unordered_map* map_value_pair) { - auto cur_in = op_item->operand_source(0); - + auto old_cond = op_item->operand_source(0); PADDLE_ENFORCE_EQ( - map_value_pair->count(cur_in), + map_value_pair->count(old_cond), true, phi::errors::PreconditionNotMet( "[%d]'s input of [%s] op MUST in map pair", 0, op_item->name())); - auto new_in = map_value_pair->at(cur_in); + auto new_cond = map_value_pair->at(old_cond); + + // NOTE(zhangbo): IfOp's input cond should be a cpu type. 
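+  // (CondInstruction::Run() reads the bool condition directly on the host,
+  // so a GPU-placed condition tensor must first be copied to CPU; the place
+  // transfer op inserted below performs that device-to-host copy.)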
+  AllocatedDenseTensorType new_cond_type =
+      new_cond.type().dyn_cast<AllocatedDenseTensorType>();
+  if (new_cond_type) {
+    if (new_cond_type.place().GetType() == phi::AllocationType::GPU) {
+      auto out_type = dialect::AllocatedDenseTensorType::get(
+          ctx,
+          phi::CPUPlace(),
+          old_cond.type().dyn_cast<dialect::DenseTensorType>());
+      phi::KernelKey kernel_key(
+          phi::Backend::GPU, phi::DataLayout::ALL_LAYOUT, phi::DataType::BOOL);
+      new_cond = AddPlaceTransferOp(new_cond,
+                                    out_type,
+                                    new_cond_type.place(),
+                                    phi::CPUPlace(),
+                                    kernel_key,
+                                    block);
+    }
+  } else {
+    PADDLE_THROW(
+        phi::errors::Unimplemented("IfOp only supports DenseTensorType"));
+  }
+
+  // Create IfOp and insert to kernel dialect program
   pir::Builder builder(ctx, block);
-
-  auto base_if_op = op_item->dyn_cast<paddle::dialect::IfOp>();
-  std::vector<pir::Type> op_output_types;
-  for (size_t i = 0; i < base_if_op.num_results(); ++i) {
-    op_output_types.push_back(paddle::dialect::AllocatedDenseTensorType::get(
+  auto old_ifop = op_item->dyn_cast<paddle::dialect::IfOp>();
+  std::vector<pir::Type> new_ifop_outputs;
+  for (size_t i = 0; i < old_ifop.num_results(); ++i) {
+    new_ifop_outputs.push_back(paddle::dialect::AllocatedDenseTensorType::get(
         ctx,
         place,
-        base_if_op.result(i).type().dyn_cast<dialect::DenseTensorType>()));
+        old_ifop.result(i).type().dyn_cast<dialect::DenseTensorType>()));
   }
-  auto new_if_op =
-      builder.Build<paddle::dialect::IfOp>(new_in, std::move(op_output_types));
+  auto new_ifop = builder.Build<paddle::dialect::IfOp>(
+      new_cond, std::move(new_ifop_outputs));
 
   // process true block
-  pir::Block* true_block = new_if_op.true_block();
+  pir::Block* true_block = new_ifop.true_block();
   ProcessBlock(place,
-               base_if_op.true_block(),
+               old_ifop.true_block(),
                true_block,
                ctx,
                map_op_pair,
                map_value_pair);
 
   // process false block
-  pir::Block* false_block = new_if_op.false_block();
+  pir::Block* false_block = new_ifop.false_block();
   ProcessBlock(place,
-               base_if_op.false_block(),
+               old_ifop.false_block(),
                false_block,
                ctx,
                map_op_pair,
                map_value_pair);
+
+  // update map
+  (*map_op_pair)[op_item] = new_ifop;
+  for (size_t i = 0; i < op_item->num_results(); ++i) {
+    (*map_value_pair)[op_item->result(i)] = new_ifop->result(i);
+  }
 }
 
 pir::OpResult GetNewInput(

From 29012fea2f511741a8a6b6b3f099cea8b6d5e1b9 Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Wed, 27 Sep 2023 07:33:42 +0000
Subject: [PATCH 02/12] add

---
 .../instruction/cond_instruction.cc           | 91 ++++++-------------
 .../instruction/cond_instruction.h            |  2 +-
 .../instruction/instruction_base.cc           | 32 ++-----
 .../instruction/instruction_base.h            | 10 +-
 .../instruction/instruction_util.cc           | 21 ++---
 .../instruction/instruction_util.h            | 11 +--
 .../instruction/legacy_kernel_instruction.cc  | 37 +++-----
 .../instruction/legacy_kernel_instruction.h   | 17 ++--
 .../instruction/phi_kernel_instruction.cc     | 29 ++----
 .../instruction/phi_kernel_instruction.h      | 16 ++--
 .../interpreter/interpreter_util.cc           | 90 ------------------
 .../interpreter/interpreter_util.h            |  9 --
 .../new_executor/new_ir_interpreter.cc        | 18 +---
 .../pir_adaptor/pir_adaptor_util.cc           | 46 ++++------
 .../pir_adaptor/pir_adaptor_util.h            | 35 +++----
 15 files changed, 123 insertions(+), 341 deletions(-)

diff --git a/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc b/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc
index 5d958d7266505..76bd0e2c13782 100644
--- a/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc
+++ b/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc
@@ -51,29 +51,20 @@ std::vector<pir::Value> GetYiedOpInputs(pir::Block* block) {
   return vec_res;
 }
 
-void GetInputIds(
-    pir::Operation* op,
-    Scope* inner_scope,
-    const std::unordered_map<::pir::Value, 
std::string>& value_2_var_name, - const std::map& var_name_2_id, - const std::unordered_map& - variable_2_var_name, - std::unordered_map>* input_ids) { +void GetInputIds(pir::Operation* op, + const ValueExecutionInfo& value_exec_info, + std::unordered_map>* input_ids) { for (size_t i = 0; i < op->num_operands(); i++) { pir::Value value = op->operand_source(i); if (value) { - PADDLE_ENFORCE_NE( - value_2_var_name.find(value), - value_2_var_name.end(), + PADDLE_ENFORCE_EQ( + value_exec_info.HasValue(value), + true, phi::errors::PreconditionNotMet( "input should in name map, [%d] 'th input of [%s] op", i, "if op")); - std::vector inputs_id = GetValueIds(value, - inner_scope, - value_2_var_name, - var_name_2_id, - variable_2_var_name); + std::vector inputs_id = GetValueIds(value, value_exec_info); input_ids->emplace(value, inputs_id); } } @@ -81,11 +72,7 @@ void GetInputIds( void GetOutsideOpInputs( pir::Block* block, - Scope* inner_scope, - const std::unordered_map<::pir::Value, std::string>& value_2_var_name, - const std::map& var_name_2_id, - const std::unordered_map& - variable_2_var_name, + const ValueExecutionInfo& value_exec_info, std::unordered_map>* input_ids) { std::unordered_set inner_outputs; for (auto op : (*block)) { @@ -98,18 +85,14 @@ void GetOutsideOpInputs( for (size_t i = 0; i < op->num_operands(); ++i) { pir::Value value = op->operand_source(i); if (value && (!inner_outputs.count(value))) { - PADDLE_ENFORCE_NE( - value_2_var_name.find(value), - value_2_var_name.end(), + PADDLE_ENFORCE_EQ( + value_exec_info.HasValue(value), + true, phi::errors::PreconditionNotMet( "input should in name map, [%d] 'th input of [%s] op", i, - "if op")); - std::vector inputs_id = GetValueIds(value, - inner_scope, - value_2_var_name, - var_name_2_id, - variable_2_var_name); + op->name())); + std::vector inputs_id = GetValueIds(value, value_exec_info); input_ids->emplace(value, inputs_id); } @@ -123,7 +106,7 @@ CondInstruction::CondInstruction( pir::Operation* op, Scope* scope, Scope* local_scope, - ValueExecutionInfo* parent_exe_info, + ValueExecutionInfo* value_exec_info, const std::map& sub_blocks) : InstructionBase(id, place) { op_ = op; @@ -144,11 +127,11 @@ CondInstruction::CondInstruction( for (size_t i = 0; i < if_op.num_results(); ++i) { if_op_outputs_.push_back(inner_scope->GetVar( - parent_exe_info->GetValue2VarName().at(if_op.result(i)))); + value_exec_info->GetValue2VarName().at(if_op.result(i)))); } auto cond_value = if_op.operand_source(0); - auto var_name = parent_exe_info->GetValue2VarName().at(cond_value); + auto var_name = value_exec_info->GetValue2VarName().at(cond_value); cond_var = inner_scope->FindVar(var_name); auto true_branch_block = if_op.true_block(); @@ -163,7 +146,7 @@ CondInstruction::CondInstruction( {}, true_branch_block, true_scope, - parent_exe_info->NewChild(true_scope), + value_exec_info->NewChild(true_scope), {}); std::set true_skip_gc_names_set; @@ -179,7 +162,7 @@ CondInstruction::CondInstruction( {}, false_branch_block, false_scope, - parent_exe_info->NewChild(false_scope), + value_exec_info->NewChild(false_scope), {}); std::set false_skip_gc_names_set; @@ -192,44 +175,25 @@ CondInstruction::CondInstruction( // the true branch and false branch input will be the if_op inputs std::unordered_map> inputs; - GetInputIds(op, - inner_scope, - parent_exe_info->GetValue2VarName(), - parent_exe_info->GetVarName2Id(), - parent_exe_info->GetVar2VarName(), - &inputs); - GetOutsideOpInputs(true_branch_block, - inner_scope, - parent_exe_info->GetValue2VarName(), - 
parent_exe_info->GetVarName2Id(), - parent_exe_info->GetVar2VarName(), - &inputs); - - GetOutsideOpInputs(false_branch_block, - inner_scope, - parent_exe_info->GetValue2VarName(), - parent_exe_info->GetVarName2Id(), - parent_exe_info->GetVar2VarName(), - &inputs); + + GetInputIds(op, *value_exec_info, &inputs); + GetOutsideOpInputs(true_branch_block, *value_exec_info, &inputs); + GetOutsideOpInputs(false_branch_block, *value_exec_info, &inputs); + SetInputs(inputs); std::unordered_map> outputs; for (size_t i = 0; i < op->num_results(); i++) { pir::Value value = op->result(i); if (value && value.type()) { - PADDLE_ENFORCE_NE( - parent_exe_info->GetValue2VarName().find(value), - parent_exe_info->GetValue2VarName().end(), + PADDLE_ENFORCE_EQ( + value_exec_info->HasValue(value), + true, phi::errors::PreconditionNotMet( "input should in name map, [%d] 'th input of [%s] op", i, "if op")); - std::vector outputs_id = - GetValueIds(value, - inner_scope, - parent_exe_info->GetValue2VarName(), - parent_exe_info->GetVarName2Id(), - parent_exe_info->GetVar2VarName()); + std::vector outputs_id = GetValueIds(value, *value_exec_info); outputs.emplace(value, outputs_id); } } @@ -247,6 +211,7 @@ void CondInstruction::CopyBranchOutput( } void CondInstruction::Run() { + DeviceContext().Wait(); if (cond_var->Get().data()[0]) { true_branch_inter->Run({}, false); CopyBranchOutput(true_skip_gc_names_, true_branch_inter); diff --git a/paddle/fluid/framework/new_executor/instruction/cond_instruction.h b/paddle/fluid/framework/new_executor/instruction/cond_instruction.h index 75eb7d0ece04f..5c3879c49905f 100644 --- a/paddle/fluid/framework/new_executor/instruction/cond_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/cond_instruction.h @@ -35,7 +35,7 @@ class CondInstruction : public InstructionBase { ::pir::Operation* op, Scope* scope, Scope* local_scope, - ValueExecutionInfo* parent_exe_info, + ValueExecutionInfo* value_exe_info, const std::map& sub_blocks); void Run() override; diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_base.cc b/paddle/fluid/framework/new_executor/instruction/instruction_base.cc index a6d2f5a201b38..0b494c29dea86 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_base.cc +++ b/paddle/fluid/framework/new_executor/instruction/instruction_base.cc @@ -16,6 +16,7 @@ #include "paddle/fluid/framework/new_executor/instruction/instruction_util.h" #include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" +#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" #include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h" @@ -214,12 +215,7 @@ void InstructionBase::SetOutputs( } void InstructionBase::InitInputsOutputsIds( - ::pir::Operation* op, - Scope* inner_scope, - const std::unordered_map& value_2_var_name, - const std::map& var_name_2_id, - const std::unordered_map& - variable_2_var_name) { + ::pir::Operation* op, const ValueExecutionInfo& value_exec_info) { auto op_attributes = op->attributes(); auto op_name = op_attributes.at("op_name").dyn_cast().AsString(); @@ -227,18 +223,14 @@ void InstructionBase::InitInputsOutputsIds( for (size_t i = 0; i < op->num_operands(); i++) { pir::Value value = op->operand_source(i); if (value) { - PADDLE_ENFORCE_NE( - value_2_var_name.find(value), - value_2_var_name.end(), + PADDLE_ENFORCE_EQ( + value_exec_info.HasValue(value), + true, phi::errors::PreconditionNotMet( "input should in 
name map, [%d] 'th input of [%s] op", i, op_name)); - std::vector inputs_id = GetValueIds(value, - inner_scope, - value_2_var_name, - var_name_2_id, - variable_2_var_name); + std::vector inputs_id = GetValueIds(value, value_exec_info); inputs.emplace(value, inputs_id); } } @@ -248,18 +240,14 @@ void InstructionBase::InitInputsOutputsIds( for (size_t i = 0; i < op->num_results(); i++) { pir::Value value = op->result(i); if (value && value.type()) { - PADDLE_ENFORCE_NE( - value_2_var_name.find(value), - value_2_var_name.end(), + PADDLE_ENFORCE_EQ( + value_exec_info.HasValue(value), + true, phi::errors::PreconditionNotMet( "input should in name map, [%d] 'th input of [%s] op", i, op_name)); - std::vector outputs_id = GetValueIds(value, - inner_scope, - value_2_var_name, - var_name_2_id, - variable_2_var_name); + std::vector outputs_id = GetValueIds(value, value_exec_info); outputs.emplace(value, outputs_id); } } diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_base.h b/paddle/fluid/framework/new_executor/instruction/instruction_base.h index 7a77e8e8fae85..6079742611915 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_base.h +++ b/paddle/fluid/framework/new_executor/instruction/instruction_base.h @@ -28,6 +28,7 @@ class Value; namespace paddle { namespace framework { +class ValueExecutionInfo; using SchedulingPriority = int64_t; @@ -139,13 +140,8 @@ class InstructionBase { virtual ::pir::Operation* Operation() const = 0; - void InitInputsOutputsIds( - ::pir::Operation* op, - Scope* inner_scope, - const std::unordered_map<::pir::Value, std::string>& value_2_var_name, - const std::map& var_name_2_id, - const std::unordered_map& - variable_2_var_name); + void InitInputsOutputsIds(::pir::Operation* op, + const ValueExecutionInfo& value_exec_info); // if scope is not null, also show dimensions of arguments virtual std::string DebugStringEx( diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_util.cc b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc index dfafd44281537..f888c0d514d51 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_util.cc +++ b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc @@ -40,22 +40,17 @@ PHI_DECLARE_bool(dynamic_static_unified_comm); namespace paddle { namespace framework { -std::vector GetValueIds( - pir::Value value, - Scope* inner_scope, - const std::unordered_map& value_2_var_name, - const std::map& var_name_2_id, - const std::unordered_map& - variable_2_var_name) { +std::vector GetValueIds(pir::Value value, + const ValueExecutionInfo& value_exec_info) { std::vector ids; - auto& var_name = value_2_var_name.at(value); - ids.push_back(var_name_2_id.at(var_name)); + ids.push_back(value_exec_info.GetVarId(value)); // NOTE(zhangbo): Value maybe a VariableRefArray - auto var = inner_scope->FindVar(var_name); + auto var = + value_exec_info.GetScope()->FindVar(value_exec_info.GetVarName(value)); if (var->IsType()) { auto& var_array = var->Get(); for (auto item : var_array) { - ids.push_back(var_name_2_id.at(variable_2_var_name.at(item))); + ids.push_back(value_exec_info.GetVarId(item)); } } return ids; @@ -147,6 +142,10 @@ OpFuncType AnalyseOpFuncType(pir::Operation* op, const platform::Place& place) { return OpFuncType::kCpuSync; } + if (op->dialect()->name() == "pd_op") { + return OpFuncType::kGpuAsync; + } + auto kernel_key = op->attributes() .at("kernel_key") .dyn_cast() diff --git 
a/paddle/fluid/framework/new_executor/instruction/instruction_util.h b/paddle/fluid/framework/new_executor/instruction/instruction_util.h index c555a101d8366..dd1b98fa3dc15 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_util.h +++ b/paddle/fluid/framework/new_executor/instruction/instruction_util.h @@ -28,13 +28,10 @@ namespace paddle { namespace framework { -std::vector GetValueIds( - pir::Value value, - Scope* inner_scope, - const std::unordered_map<::pir::Value, std::string>& value_2_var_name, - const std::map& var_name_2_id, - const std::unordered_map& - variable_2_var_name); +class ValueExecutionInfo; + +std::vector GetValueIds(pir::Value value, + const ValueExecutionInfo& value_exec_info); platform::DeviceContext* ParseDeviceContext( pir::Operation* op, diff --git a/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc index 748c7e603f7d7..17479db817e68 100644 --- a/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc @@ -36,13 +36,8 @@ LegacyKernelInstruction::LegacyKernelInstruction( size_t id, const platform::Place& place, pir::Operation* op, - Scope* scope, - Scope* local_scope, - const std::unordered_map& value_2_var_name, - const std::map& var_name_2_id, - const std::unordered_map& - variable_2_var_name) - : InstructionBase(id, place) { + const ValueExecutionInfo& value_exec_info) + : InstructionBase(id, place), value_exec_info_(value_exec_info) { auto& op_attributes = op->attributes(); auto op_name = op_attributes.at("op_name").dyn_cast().AsString(); @@ -105,12 +100,7 @@ LegacyKernelInstruction::LegacyKernelInstruction( phi::MetaTensor, paddle::small_vector, paddle::small_vector, - false>(op, - value_2_var_name, - scope, - local_scope, - yaml_info_parser, - &infer_meta_context_); + false>(op, value_exec_info_, yaml_info_parser, &infer_meta_context_); } VLOG(6) << "finish process infer meta context"; @@ -126,10 +116,11 @@ LegacyKernelInstruction::LegacyKernelInstruction( phi_kernel_->IsValid(), true, "not found kernel for [%s]", kernel_name); VLOG(6) << "finish process select kernel: " << kernel_name; - Scope* inner_scope = local_scope == nullptr ? 
scope : local_scope; + const Scope* inner_scope = value_exec_info_.GetScope(); + + operator_base_ = + pir::BuildOperatorBase(op, value_exec_info_, yaml_info_parser); - operator_base_ = pir::BuildOperatorBase( - op, value_2_var_name, yaml_info_parser, variable_2_var_name, inner_scope); paddle::framework::VariableValueMap in_map; paddle::framework::VariableValueMap out_map; auto dev_ctx = phi::DeviceContextPool::Instance().Get( @@ -137,14 +128,11 @@ LegacyKernelInstruction::LegacyKernelInstruction( runtime_context_ = std::make_shared( paddle::framework::RuntimeContext(in_map, out_map)); - pir::BuildRuntimeContext(op, - value_2_var_name, - scope, - local_scope, - yaml_info_parser, - runtime_context_.get()); + pir::BuildRuntimeContext( + op, value_exec_info, yaml_info_parser, runtime_context_.get()); + kernel_context_ = new paddle::framework::ExecutionContext( - *operator_base_, *local_scope, *dev_ctx, *(runtime_context_.get())); + *operator_base_, *inner_scope, *dev_ctx, *(runtime_context_.get())); VLOG(6) << "finish process kernel context"; SetDeviceContext( @@ -156,8 +144,7 @@ LegacyKernelInstruction::LegacyKernelInstruction( GetStreamPriority())); VLOG(6) << "finish process device context"; - InitInputsOutputsIds( - op, inner_scope, value_2_var_name, var_name_2_id, variable_2_var_name); + InitInputsOutputsIds(op, value_exec_info); VLOG(6) << "finish process inputs outputs index"; auto& no_need_buffer_ids = yaml_info_parser.NoNeedBufferIds(); diff --git a/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.h b/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.h index 9c6fbd9b7d807..0f122ae9565a1 100644 --- a/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.h @@ -24,19 +24,14 @@ class Value; namespace paddle { namespace framework { class Scope; +class ValueExecutionInfo; class LegacyKernelInstruction : public InstructionBase { public: - LegacyKernelInstruction( - size_t id, - const platform::Place& place, - ::pir::Operation* op, - Scope* scope, - Scope* local_scope, - const std::unordered_map<::pir::Value, std::string>& value_2_var_name, - const std::map& var_name_2_id, - const std::unordered_map& - variable_2_var_name); + LegacyKernelInstruction(size_t id, + const platform::Place& place, + ::pir::Operation* op, + const ValueExecutionInfo& value_exec_info); ~LegacyKernelInstruction(); phi::Kernel* PhiKernel() const { return phi_kernel_; } @@ -70,6 +65,8 @@ class LegacyKernelInstruction : public InstructionBase { phi::Kernel* phi_kernel_{nullptr}; // not owned ::pir::Operation* op_{nullptr}; // not owned + + const ValueExecutionInfo& value_exec_info_; // not owned }; } // namespace framework diff --git a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc index e779fb52f26e4..7b5e5afb0fac6 100644 --- a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc @@ -40,13 +40,8 @@ PhiKernelInstruction::PhiKernelInstruction( size_t id, const platform::Place& place, pir::Operation* op, - Scope* scope, - Scope* local_scope, - const std::unordered_map& value_2_var_name, - const std::map& var_name_2_id, - const std::unordered_map& - variable_2_var_name) - : InstructionBase(id, place) { + const ValueExecutionInfo& value_exec_info) + : InstructionBase(id, 
place), value_exec_info_(value_exec_info) { auto op_attributes = op->attributes(); auto op_name = op_attributes.at("op_name").dyn_cast().AsString(); @@ -109,12 +104,7 @@ PhiKernelInstruction::PhiKernelInstruction( phi::MetaTensor, paddle::small_vector, paddle::small_vector, - false>(op, - value_2_var_name, - scope, - local_scope, - yaml_info_parser, - &infer_meta_context_); + false>(op, value_exec_info_, yaml_info_parser, &infer_meta_context_); } VLOG(6) << "finish process infer meta context"; @@ -135,12 +125,9 @@ PhiKernelInstruction::PhiKernelInstruction( phi::TensorBase*, paddle::small_vector, paddle::small_vector, - true>(op, - value_2_var_name, - scope, - local_scope, - yaml_info_parser, - &kernel_context_); + true>( + op, value_exec_info_, yaml_info_parser, &kernel_context_); + kernel_context_.SetDeviceContext(phi::DeviceContextPool::Instance().Get( phi::TransToPhiPlace(kernel_key.backend()))); VLOG(6) << "finish process kernel context"; @@ -154,9 +141,7 @@ PhiKernelInstruction::PhiKernelInstruction( GetStreamPriority())); VLOG(6) << "finish process device context"; - Scope* inner_scope = local_scope == nullptr ? scope : local_scope; - InitInputsOutputsIds( - op, inner_scope, value_2_var_name, var_name_2_id, variable_2_var_name); + InitInputsOutputsIds(op, value_exec_info); VLOG(6) << "finish process inputs outputs index"; auto& no_need_buffer_ids = yaml_info_parser.NoNeedBufferIds(); diff --git a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h index 96484f435a9f7..31018c8295625 100644 --- a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h @@ -27,16 +27,10 @@ class Value; class PhiKernelInstruction : public InstructionBase { public: - PhiKernelInstruction( - size_t id, - const platform::Place& place, - ::pir::Operation* op, - Scope* scope, - Scope* local_scope, - const std::unordered_map<::pir::Value, std::string>& value_2_var_name, - const std::map& var_name_2_id, - const std::unordered_map& - variable_2_var_name); + PhiKernelInstruction(size_t id, + const platform::Place& place, + ::pir::Operation* op, + const ValueExecutionInfo& value_exec_info); ~PhiKernelInstruction(); @@ -71,6 +65,8 @@ class PhiKernelInstruction : public InstructionBase { std::string phi_op_name_; ::pir::Operation* op_{nullptr}; // not owned + + const ValueExecutionInfo& value_exec_info_; // not owned }; } // namespace framework diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index 7002e2e787680..e6b47b9d234c9 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -1029,96 +1029,6 @@ void BuildOpFuncList(const platform::Place& place, delete garbages; } -void BuildOpFuncList( - const platform::Place& place, - pir::Block* block, - std::vector* vec_func_list, - framework::Scope* scope, - framework::Scope* local_scope, - const std::unordered_map& value_2_name_map, - const ExecutionConfig& execution_config) { - vec_func_list->reserve(block->size()); - pir::IrContext* ctx = pir::IrContext::Instance(); - - ctx->GetOrRegisterDialect(); - - for (auto op : *block) { - OpFuncNode op_func_node; - auto attr_map = op->attributes(); - - auto op_name = - attr_map.at("op_name").dyn_cast().AsString(); - 
op_func_node.phi_op_name_ = op_name; - - if (GetSpecialOpNames().count(op_name)) { - VLOG(6) << "skip process " << op_name; - continue; - } - - pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name); - - auto impl = - op_info.GetInterfaceImpl(); - - op_func_node.infer_meta_interface_ = - op_info.GetInterfaceImpl(); - - VLOG(6) << "op name" << op_func_node.phi_op_name_; - dialect::OpYamlInfoParser op_yaml_info_parser(impl->get_op_info_()); - if (op_func_node.infer_meta_interface_) { - pir::BuildPhiContext< - phi::InferMetaContext, - phi::MetaTensor, - phi::MetaTensor, - paddle::small_vector, - paddle::small_vector, - false>(op, - value_2_name_map, - scope, - local_scope, - op_yaml_info_parser, - &(op_func_node.infer_meta_context_)); - } - - auto kernel_name = - attr_map.at("kernel_name").dyn_cast().AsString(); - auto kernel_key = attr_map.at("kernel_key") - .dyn_cast() - .data(); - - VLOG(6) << "finish process infer meta context"; - auto t1 = phi::KernelFactory::Instance().SelectKernelOrThrowError( - kernel_name, kernel_key); - op_func_node.phi_kernel_ = new phi::Kernel(t1.kernel); - - PADDLE_ENFORCE_EQ(op_func_node.phi_kernel_->IsValid(), - true, - "not found kernel for [%s]", - kernel_name); - - pir::BuildPhiContext, - paddle::small_vector, - true>(op, - value_2_name_map, - scope, - local_scope, - op_yaml_info_parser, - &(op_func_node.kernel_context_)); - - VLOG(6) << "finish process kernel context"; - op_func_node.kernel_context_.SetDeviceContext( - phi::DeviceContextPool::Instance().Get( - phi::TransToPhiPlace(kernel_key.backend()))); - op_func_node.dev_ctx_ = phi::DeviceContextPool::Instance().Get( - phi::TransToPhiPlace(kernel_key.backend())); - - vec_func_list->emplace_back(op_func_node); - } -} - void BuildVariableScope(const framework::BlockDesc& block, const ExecutionConfig& execution_config, VariableScope* var_scope) { diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h index 213804ec980f6..df6113d5a57a2 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h @@ -104,15 +104,6 @@ void BuildOpFuncList(const platform::Place& place, bool use_local_scope = true, bool static_build = false); -void BuildOpFuncList( - const platform::Place& place, - ::pir::Block* block, - std::vector* vec_func_list, - framework::Scope* scope, - framework::Scope* local_scope, - const std::unordered_map<::pir::Value, std::string>& value_2_name_map, - const ExecutionConfig& execution_config); - void BuildVariableScope(const framework::BlockDesc& block, const ExecutionConfig& execution_config, VariableScope* var_scope); diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc index ee2c4c3ea62ed..0c04f6041a755 100644 --- a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc @@ -586,25 +586,11 @@ void NewIRInterpreter::BuildInstruction() { if (op->name().compare(paddle::dialect::LegacyKernelOp::name()) == 0) { vec_instruction_base_.emplace_back( std::make_unique( - op_idx++, - place_, - op, - scope_, - local_scope_, - value_exe_info_->GetValue2VarName(), - value_exe_info_->GetVarName2Id(), - value_exe_info_->GetVar2VarName())); + op_idx++, place_, op, *(value_exe_info_.get()))); } else { vec_instruction_base_.emplace_back( std::make_unique( - op_idx++, - place_, - op, - 
scope_, - local_scope_, - value_exe_info_->GetValue2VarName(), - value_exe_info_->GetVarName2Id(), - value_exe_info_->GetVar2VarName())); + op_idx++, place_, op, *(value_exe_info_.get()))); } #ifdef PADDLE_WITH_CINN } else if (op->dialect()->name() == "cinn_runtime") { diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc index f2aa098d40eab..16ccc16acf41c 100644 --- a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc +++ b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc @@ -331,18 +331,17 @@ paddle::framework::Variable* CreateVar( return var; } -void CheckInputVars( - pir::Operation* op, - const std::string& op_name, - const std::unordered_map& value_2_var_name) { +void CheckInputVars(pir::Operation* op, + const std::string& op_name, + paddle::framework::ValueExecutionInfo* execution_info) { size_t input_num = op->num_operands(); if (input_num > 0) { for (size_t i = 0; i < input_num; ++i) { auto value = op->operand_source(i); if (IsInvalid(value)) { PADDLE_ENFORCE_NE( - value_2_var_name.find(value), - value_2_var_name.end(), + execution_info->HasValue(value), + true, phi::errors::PreconditionNotMet( "input should in name map, [%d] 'th input of [%s] op", i, @@ -664,7 +663,7 @@ void BuildScope(const pir::Block& block, continue; } - CheckInputVars(op, op_name, value_exe_info->GetValue2VarName()); + CheckInputVars(op, op_name, value_exe_info); if (op->num_results() < 1) continue; if (op->attributes().count("is_inplace") != 0 && @@ -690,15 +689,11 @@ void BuildScope(const pir::Block& block, void BuildRuntimeContext( pir::Operation* op, - const std::unordered_map& name_map, - paddle::framework::Scope* scope, - paddle::framework::Scope* local_scope, + const paddle::framework::ValueExecutionInfo& value_exec_info, const paddle::dialect::OpYamlInfoParser& op_yaml_info, paddle::framework::RuntimeContext* runtime_ctx) { - paddle::framework::Scope* inner_scope = - local_scope != nullptr ? 
local_scope : scope; - VLOG(6) << "BuildPhiContext in scope[" << scope << "] inner_scope[" - << inner_scope << "]"; + const paddle::framework::Scope* inner_scope = value_exec_info.GetScope(); + VLOG(6) << "BuildPhiContext in scope[" << inner_scope << "]"; auto& vec_kernel_fn_tensor_params = op_yaml_info.TensorParams(true); @@ -721,7 +716,7 @@ void BuildRuntimeContext( } auto legacy_attr_name = op_normalizer.GetLegacyArgName(fluid_op_name, name); - auto in_var_name = name_map.at(ptr); + auto in_var_name = value_exec_info.GetVarName(ptr); VLOG(6) << "ctx->EmplaceBackInput: " << name << "\t" << in_var_name; PADDLE_ENFORCE_NOT_NULL(inner_scope->FindVar(in_var_name), phi::errors::PreconditionNotMet( @@ -740,7 +735,7 @@ void BuildRuntimeContext( continue; } - auto in_var_name = name_map.at(ptr); + auto in_var_name = value_exec_info.GetVarName(ptr); VLOG(6) << "ctx->EmplaceBackOutput: " << name << "\t" << in_var_name; PADDLE_ENFORCE_NOT_NULL(inner_scope->FindVar(in_var_name), @@ -771,11 +766,8 @@ void BuildRuntimeContext( std::shared_ptr BuildOperatorBase( pir::Operation* op, - const std::unordered_map& name_map, - const paddle::dialect::OpYamlInfoParser& op_yaml_info, - const std::unordered_map& - variable_2_var_name, - const paddle::framework::Scope* scope) { + const paddle::framework::ValueExecutionInfo& value_exec_info, + const paddle::dialect::OpYamlInfoParser& op_yaml_info) { paddle::framework::VariableNameMap in_name_map; paddle::framework::VariableNameMap out_name_map; paddle::framework::AttributeMap attr_map; @@ -787,6 +779,8 @@ std::shared_ptr BuildOperatorBase( auto& op_normalizer = paddle::translator::OpNameNormalizer::instance(); + auto scope = value_exec_info.GetScope(); + // build inputs for (auto& name : vec_kernel_fn_tensor_params) { PADDLE_ENFORCE_EQ( @@ -801,7 +795,7 @@ std::shared_ptr BuildOperatorBase( continue; } - in_name_map[legacy_attr_name].push_back(name_map.at(ptr)); + in_name_map[legacy_attr_name].push_back(value_exec_info.GetVarName(ptr)); } // build attribute @@ -890,15 +884,13 @@ std::shared_ptr BuildOperatorBase( if (ptr.type().isa() || ptr.type().isa()) { - out_name_map[legacy_arg_name].push_back(name_map.at(ptr)); + out_name_map[legacy_arg_name].push_back(value_exec_info.GetVarName(ptr)); } else if (ptr.type().isa()) { - auto var = scope->FindVar(name_map.at(ptr)); + auto var = scope->FindVar(value_exec_info.GetVarName(ptr)); auto var_ref = var->Get(); for (size_t k = 0; k < var_ref.size(); ++k) { - PADDLE_ENFORCE(variable_2_var_name.count(var_ref[k]), - "Variable MUST in variable_2_var_name map"); out_name_map[legacy_arg_name].push_back( - variable_2_var_name.at(var_ref[k])); + value_exec_info.GetVarName(var_ref[k])); } } else { PADDLE_THROW(phi::errors::Unimplemented( diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h index 82f93a417a32f..0cc7f558cd1d4 100644 --- a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h +++ b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h @@ -164,19 +164,14 @@ void BuildScope( void BuildRuntimeContext( pir::Operation* op, - const std::unordered_map& name_map, - paddle::framework::Scope* scope, - paddle::framework::Scope* local_scope, + const paddle::framework::ValueExecutionInfo& value_exec_info, const paddle::dialect::OpYamlInfoParser& op_yaml_info, paddle::framework::RuntimeContext* runtime_ctx); std::shared_ptr BuildOperatorBase( pir::Operation* op, - const std::unordered_map& name_map, - const 
paddle::dialect::OpYamlInfoParser& op_yaml_info, - const std::unordered_map& - variable_2_var_name, - const paddle::framework::Scope* scope); + const paddle::framework::ValueExecutionInfo& value_exec_info, + const paddle::dialect::OpYamlInfoParser& op_yaml_info); template void BuildPhiContext( pir::Operation* op, - const std::unordered_map& name_map, - paddle::framework::Scope* scope, - paddle::framework::Scope* local_scope, + const paddle::framework::ValueExecutionInfo& value_exec_info, const paddle::dialect::OpYamlInfoParser& op_yaml_info, Context* ctx) { - paddle::framework::Scope* inner_scope = - local_scope != nullptr ? local_scope : scope; - VLOG(6) << "Build " << get_type_name() << " in scope[" << scope - << "] inner_scope[" << inner_scope << "]"; + paddle::framework::Scope* inner_scope = value_exec_info.GetScope(); + VLOG(6) << "Build " << get_type_name() << "] inner_scope[" + << inner_scope << "]"; auto attr_map = op->attributes(); @@ -223,7 +215,7 @@ void BuildPhiContext( continue; } - auto in_var_name = name_map.at(ptr); + auto in_var_name = value_exec_info.GetVarName(ptr); VLOG(6) << "ctx->EmplaceBackInput: " << t << "\t" << in_var_name; PADDLE_ENFORCE_NOT_NULL(inner_scope->FindVar(in_var_name), @@ -264,7 +256,7 @@ void BuildPhiContext( // tensor attribute, get information from input pir::Value ptr = op->operand_source(name2id.at(t)); - auto in_var_name = name_map.at(ptr); + auto in_var_name = value_exec_info.GetVarName(ptr); auto& tensor_attr_type = op_yaml_info.TensorAttrTypeName(t); VLOG(6) << "ctx->EmplaceBack mutable attr: " << t << "\t" << in_var_name; @@ -438,17 +430,18 @@ void BuildPhiContext( if (out_ptr.type().isa()) { ctx->EmplaceBackOutput(OutType(const_cast( - &(inner_scope->FindVar(name_map.at(out_ptr)) + &(inner_scope->FindVar(value_exec_info.GetVarName(out_ptr)) ->Get())))); } else if (out_ptr.type() .isa()) { ctx->EmplaceBackOutput(OutType(const_cast( - &(inner_scope->FindVar(name_map.at(out_ptr)) + &(inner_scope->FindVar(value_exec_info.GetVarName(out_ptr)) ->Get())))); } else if (out_ptr.type().isa()) { OutListType outputs; - auto& variable_array = inner_scope->FindVar(name_map.at(out_ptr)) - ->Get(); + auto& variable_array = + inner_scope->FindVar(value_exec_info.GetVarName(out_ptr)) + ->Get(); for (size_t i = 0; i < variable_array.size(); ++i) { if (variable_array[i]->IsType()) { outputs.emplace_back(OutType(const_cast( From a2a0567898728f7753b93c0b2bf2a9ef8d1fad55 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Wed, 27 Sep 2023 08:07:50 +0000 Subject: [PATCH 03/12] fix --- paddle/fluid/framework/new_executor/instruction/CMakeLists.txt | 2 +- .../new_executor/instruction/legacy_kernel_instruction.h | 1 - .../framework/new_executor/instruction/phi_kernel_instruction.h | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt b/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt index 0623499975b6f..64c14374162c6 100644 --- a/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt +++ b/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt @@ -2,7 +2,7 @@ cc_library( instruction_base SRCS instruction_base.cc phi_kernel_instruction.cc legacy_kernel_instruction.cc cond_instruction.cc instruction_util.cc - DEPS phi framework_proto) + DEPS pir_adaptor phi framework_proto) if(WITH_CINN AND NOT CINN_ONLY) cc_library( diff --git a/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.h 
b/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.h index 0f122ae9565a1..1ccbc8ebc0158 100644 --- a/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.h @@ -18,7 +18,6 @@ namespace pir { class Operation; -class Value; } // namespace pir namespace paddle { diff --git a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h index 31018c8295625..41539300c4503 100644 --- a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h @@ -23,7 +23,7 @@ class Operation; namespace paddle { namespace framework { class Scope; -class Value; +class ValueExecutionInfo; class PhiKernelInstruction : public InstructionBase { public: From d78c9865c922d34637e7659f04012fc947a1ec11 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Wed, 27 Sep 2023 08:53:18 +0000 Subject: [PATCH 04/12] fix --- .../framework/new_executor/pir_adaptor/pir_adaptor_util.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc index bd8e7440863b4..bc172fe4f927e 100644 --- a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc +++ b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc @@ -339,7 +339,7 @@ void CheckInputVars(pir::Operation* op, for (size_t i = 0; i < input_num; ++i) { auto value = op->operand_source(i); if (IsInvalid(value)) { - PADDLE_ENFORCE_NE( + PADDLE_ENFORCE_EQ( execution_info->HasValue(value), true, phi::errors::PreconditionNotMet( From 838185c699c0c2a1618c1510c0d0d903c84542d8 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Wed, 27 Sep 2023 10:57:28 +0000 Subject: [PATCH 05/12] refine --- .../instruction/legacy_kernel_instruction.cc | 7 +- .../instruction/phi_kernel_instruction.cc | 14 +- .../new_executor/new_ir_interpreter.cc | 4 +- .../pir_adaptor/pir_adaptor_util.cc | 144 ++++++++---------- .../pir_adaptor/pir_adaptor_util.h | 53 +++---- 5 files changed, 98 insertions(+), 124 deletions(-) diff --git a/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc index 17479db817e68..97bda34777008 100644 --- a/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc @@ -94,7 +94,7 @@ LegacyKernelInstruction::LegacyKernelInstruction( VLOG(6) << "finish process yaml_info_parser"; if (infer_meta_interface_) { - pir::BuildPhiContext< + BuildPhiContext< phi::InferMetaContext, phi::MetaTensor, phi::MetaTensor, @@ -118,8 +118,7 @@ LegacyKernelInstruction::LegacyKernelInstruction( const Scope* inner_scope = value_exec_info_.GetScope(); - operator_base_ = - pir::BuildOperatorBase(op, value_exec_info_, yaml_info_parser); + operator_base_ = BuildOperatorBase(op, value_exec_info_, yaml_info_parser); paddle::framework::VariableValueMap in_map; paddle::framework::VariableValueMap out_map; @@ -128,7 +127,7 @@ LegacyKernelInstruction::LegacyKernelInstruction( runtime_context_ = std::make_shared( paddle::framework::RuntimeContext(in_map, out_map)); - pir::BuildRuntimeContext( + BuildRuntimeContext( op, value_exec_info, yaml_info_parser, 
runtime_context_.get()); kernel_context_ = new paddle::framework::ExecutionContext( diff --git a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc index 7b5e5afb0fac6..3f93161a363fa 100644 --- a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc @@ -98,7 +98,7 @@ PhiKernelInstruction::PhiKernelInstruction( VLOG(6) << "finish process yaml_info_parser"; if (infer_meta_interface_) { - pir::BuildPhiContext< + BuildPhiContext< phi::InferMetaContext, phi::MetaTensor, phi::MetaTensor, @@ -120,12 +120,12 @@ PhiKernelInstruction::PhiKernelInstruction( phi_kernel_->IsValid(), true, "not found kernel for [%s]", kernel_name); VLOG(6) << "finish process select kernel"; - pir::BuildPhiContext, - paddle::small_vector, - true>( + BuildPhiContext, + paddle::small_vector, + true>( op, value_exec_info_, yaml_info_parser, &kernel_context_); kernel_context_.SetDeviceContext(phi::DeviceContextPool::Instance().Get( diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc index 0c04f6041a755..4d133135d4c17 100644 --- a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc @@ -114,7 +114,7 @@ NewIRInterpreter::NewIRInterpreter( std::stringstream ss; ss << this; - ::pir::BuildScope(*ir_block_, ss.str(), &sub_blocks_, value_exe_info_.get()); + BuildScope(*ir_block_, ss.str(), &sub_blocks_, value_exe_info_.get()); } NewIRInterpreter::NewIRInterpreter( @@ -176,7 +176,7 @@ NewIRInterpreter::NewIRInterpreter( std::stringstream ss; ss << this; - ::pir::BuildScope(*ir_block_, ss.str(), &sub_blocks_, value_exe_info_.get()); + BuildScope(*ir_block_, ss.str(), &sub_blocks_, value_exe_info_.get()); } NewIRInterpreter::~NewIRInterpreter() { diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc index bc172fe4f927e..d0abdf02e2a83 100644 --- a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc +++ b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc @@ -13,37 +13,34 @@ // limitations under the License. 
#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" -#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" -#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h" -#include "paddle/fluid/pir/dialect/operator/utils/utils.h" -#include "paddle/phi/core/meta_tensor.h" -#include "paddle/pir/core/builtin_attribute.h" -#include "paddle/pir/core/builtin_op.h" -#include "paddle/pir/core/ir_context.h" -#include "paddle/pir/core/program.h" -#include "paddle/pir/core/utils.h" +#include "glog/logging.h" +#include "paddle/fluid/framework/op_info.h" +#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/string_array.h" #include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/tensor_ref_array.h" #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable_helper.h" -#include "paddle/phi/core/kernel_context.h" - -#include "paddle/fluid/framework/string_array.h" -#include "paddle/fluid/framework/tensor_ref_array.h" #include "paddle/fluid/ir_adaptor/translator/op_compat_info.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" +#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/phi/core/enforce.h" - -#include "glog/logging.h" -#include "paddle/fluid/framework/op_info.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" +#include "paddle/phi/core/kernel_context.h" +#include "paddle/phi/core/meta_tensor.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/ir_context.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/core/utils.h" namespace paddle { namespace framework { @@ -136,7 +133,7 @@ void ValueExecutionInfo::AddValue2VarName(::pir::Value value, value_2_var_name_.emplace(value, var_name); } -const std::unordered_map& +const std::unordered_map& ValueExecutionInfo::GetVar2VarName() const { return var_2_var_name_; } @@ -279,11 +276,6 @@ int ValueExecutionInfo::GetVarIdLocally(const Variable* var) const { return -1; } -} // namespace framework -} // namespace paddle - -namespace pir { - const std::unordered_set SpecialOps = {"pd_op.feed", "pd_op.fetch", "builtin.combine", @@ -295,21 +287,17 @@ const std::unordered_set SpecialOps = {"pd_op.feed", "pd_op.shadow_output", "pd_op.if"}; -using VariableNameMap = - std::unordered_map; - -paddle::framework::Variable* CreateVar( - pir::Value value, - const std::string& var_name_prefix, - bool force_persisable, - paddle::framework::ValueExecutionInfo* value_exe_info) { - Operation* def_op = value.dyn_cast().owner(); +Variable* CreateVar(pir::Value value, + const std::string& var_name_prefix, + bool force_persisable, + ValueExecutionInfo* value_exe_info) { + pir::Operation* def_op = value.dyn_cast().owner(); bool is_persisable = false; if 
(def_op->isa<::pir::SetParameterOp>()) { is_persisable = true; } - paddle::framework::Variable* var = nullptr; + Variable* var = nullptr; std::string name = var_name_prefix + "_inner_var_" + std::to_string(value_exe_info->GetVar2VarName().size()); @@ -317,9 +305,7 @@ paddle::framework::Variable* CreateVar( if (force_persisable || is_persisable) { VLOG(6) << "Create var: " << name << " in scope " << value_exe_info->GetScope()->root(); - var = const_cast( - value_exe_info->GetScope()->root()) - ->Var(name); + var = const_cast(value_exe_info->GetScope()->root())->Var(name); } else { VLOG(6) << "Create var: " << name << " in scope " << value_exe_info->GetScope(); @@ -333,7 +319,7 @@ paddle::framework::Variable* CreateVar( void CheckInputVars(pir::Operation* op, const std::string& op_name, - paddle::framework::ValueExecutionInfo* execution_info) { + ValueExecutionInfo* execution_info) { size_t input_num = op->num_operands(); if (input_num > 0) { for (size_t i = 0; i < input_num; ++i) { @@ -353,13 +339,13 @@ void CheckInputVars(pir::Operation* op, void BuildValue(pir::Value value, const std::string& var_name_prefix, - paddle::framework::ValueExecutionInfo* value_exe_info) { + ValueExecutionInfo* value_exe_info) { if (!IsInvalid(value)) { VLOG(8) << "Value is not invalid, so skip build a variable."; return; } - paddle::framework::Variable* var = nullptr; + Variable* var = nullptr; auto& value_2_var_name = value_exe_info->GetValue2VarName(); if (value_2_var_name.find(value) != value_2_var_name.end()) { var = value_exe_info->GetScope()->FindVar(value_2_var_name.at(value)); @@ -373,7 +359,7 @@ void BuildValue(pir::Value value, } else if (value.type().isa()) { var->GetMutable(); } else if (value.type().isa()) { - auto tensor_array = var->GetMutable(); + auto tensor_array = var->GetMutable(); for (size_t i = 0; i < value.type().dyn_cast().size(); i++) { PADDLE_ENFORCE(value.type() @@ -393,11 +379,10 @@ void BuildValue(pir::Value value, } } -void HandleForSpecialOp( - pir::Operation* op, - const std::string& var_name_prefix, - std::map* sub_blocks, - paddle::framework::ValueExecutionInfo* value_exe_info) { +void HandleForSpecialOp(pir::Operation* op, + const std::string& var_name_prefix, + std::map* sub_blocks, + ValueExecutionInfo* value_exe_info) { std::string op_name = op->name(); if (op->attributes().count("op_name")) { op_name = @@ -410,8 +395,7 @@ void HandleForSpecialOp( op->attributes().at("name").dyn_cast().AsString(); auto fetch_var_name = fetch_src_name + "@fetch"; - auto* var = const_cast( - value_exe_info->GetScope()->root()) + auto* var = const_cast(value_exe_info->GetScope()->root()) ->Var(fetch_var_name); var->GetMutable(); auto value = op->result(0); @@ -427,7 +411,7 @@ void HandleForSpecialOp( std::string name = op->attributes().at("name").dyn_cast().AsString(); - paddle::framework::Variable* var = value_exe_info->GetScope()->Var(name); + Variable* var = value_exe_info->GetScope()->Var(name); PADDLE_ENFORCE(var, paddle::platform::errors::InvalidArgument( "The variable %s shoud exist", name)); @@ -438,7 +422,7 @@ void HandleForSpecialOp( if (op_name == "builtin.combine") { auto out_value = op->result(0); - paddle::framework::Variable* var = nullptr; + Variable* var = nullptr; auto& value_2_var_name = value_exe_info->GetValue2VarName(); if (value_2_var_name.find(out_value) != value_2_var_name.end()) { var = value_exe_info->GetScope()->FindVar(value_2_var_name.at(out_value)); @@ -446,7 +430,7 @@ void HandleForSpecialOp( var = CreateVar(out_value, var_name_prefix, false, value_exe_info); 
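      // (CreateVar also registers the fresh variable with value_exe_info via
      // Add(), so later value-to-name lookups for out_value succeed.)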
}
-    auto tensor_array = var->GetMutable<paddle::framework::VariableRefArray>();
+    auto tensor_array = var->GetMutable<VariableRefArray>();
     // clear tensor array
     tensor_array->clear();
     size_t input_num = op->num_operands();
@@ -479,7 +463,7 @@ void HandleForSpecialOp(
                           "SetParameter param name should not be equal to var name"));
     if (value_exe_info->GetScope()->root()->FindVar(param_name) == nullptr) {
-      const_cast<paddle::framework::Scope*>(value_exe_info->GetScope()->root())
+      const_cast<Scope*>(value_exe_info->GetScope()->root())
           ->Rename(orig_name, param_name);
       VLOG(6) << "set_parameter rename var: " << orig_name << " -> "
               << param_name;
@@ -498,7 +482,7 @@ void HandleForSpecialOp(
     auto orig_name = value_exe_info->GetValue2VarName().at(value);
 
     if (value_exe_info->GetScope()->root()->FindVar(var_name) == nullptr) {
-      const_cast<paddle::framework::Scope*>(value_exe_info->GetScope()->root())
+      const_cast<Scope*>(value_exe_info->GetScope()->root())
          ->Rename(orig_name, var_name);
     }
@@ -529,7 +513,7 @@ void HandleForSpecialOp(
         op->attributes().at("index").dyn_cast<pir::Int32Attribute>().data();
     auto in_var = value_exe_info->GetScope()->FindVar(
         value_exe_info->GetValue2VarName().at(in_value));
-    auto variable_array = in_var->Get<paddle::framework::VariableRefArray>();
+    auto variable_array = in_var->Get<VariableRefArray>();
 
     PADDLE_ENFORCE_EQ(
         value_exe_info->GetVar2VarName().count(variable_array[index]),
@@ -553,7 +537,7 @@ void HandleForSpecialOp(
     auto in_var = value_exe_info->GetScope()->FindVar(
         value_exe_info->GetValue2VarName().at(in_value));
-    auto variable_array = in_var->Get<paddle::framework::VariableRefArray>();
+    auto variable_array = in_var->Get<VariableRefArray>();
 
     for (uint64_t idx = 0; idx < variable_array.size(); ++idx) {
       auto out_value = op->result(idx);
@@ -593,7 +577,7 @@ void HandleForSpecialOp(
 
 void HandleForInplaceOp(pir::Operation* op,
                         const std::string& var_name_prefix,
-                        paddle::framework::ValueExecutionInfo* value_exe_info) {
+                        ValueExecutionInfo* value_exe_info) {
   if (op->num_results() < 1) return;
   pir::IrContext* ctx = pir::IrContext::Instance();
   std::string op_name = op->name();
@@ -644,13 +628,12 @@ void HandleForInplaceOp(pir::Operation* op,
 // is created in inner_scope.
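For reviewer orientation, the shape that BuildScope converges on over this series can be sketched as follows. This is a paraphrase assembled from the hunks in this file, not the literal implementation; the inplace-op branch and logging are elided, and the post-series signature (without sub_blocks) is assumed:

    // Sketch: walk the block once, giving every op result a Variable.
    void BuildScopeSketch(const pir::Block& block,
                          const std::string& var_name_prefix,
                          ValueExecutionInfo* value_exe_info) {
      for (auto op : block) {
        if (SpecialOps.count(op->name())) {  // e.g. pd_op.fetch, builtin.combine
          HandleForSpecialOp(op, var_name_prefix, value_exe_info);
          continue;
        }
        CheckInputVars(op, op->name(), value_exe_info);  // inputs must exist
        for (size_t i = 0; i < op->num_results(); ++i) {
          BuildValue(op->result(i), var_name_prefix, value_exe_info);
        }
      }
    }
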
void BuildScope(const pir::Block& block,
                const std::string& var_name_prefix,
-               std::map<pir::Block*, paddle::framework::Scope*>* sub_blocks,
-               paddle::framework::ValueExecutionInfo* value_exe_info) {
+               std::map<pir::Block*, Scope*>* sub_blocks,
+               ValueExecutionInfo* value_exe_info) {
   VLOG(4) << "***** [before build] scope"
           << "(" << value_exe_info->GetScope() << ") ******\n"
-          << paddle::framework::GenScopeTreeDebugInfo(
-                 const_cast<paddle::framework::Scope*>(
-                     value_exe_info->GetScope()->root()));
+          << GenScopeTreeDebugInfo(
+                 const_cast<Scope*>(value_exe_info->GetScope()->root()));
 
   for (auto op : block) {
     std::string op_name = op->name();
@@ -685,17 +668,15 @@ void BuildScope(const pir::Block& block,
   VLOG(4) << "***** [after build] scope"
           << "(" << value_exe_info->GetScope() << ") ******\n"
-          << paddle::framework::GenScopeTreeDebugInfo(
-                 const_cast<paddle::framework::Scope*>(
-                     value_exe_info->GetScope()->root()));
+          << GenScopeTreeDebugInfo(
+                 const_cast<Scope*>(value_exe_info->GetScope()->root()));
 }
 
-void BuildRuntimeContext(
-    pir::Operation* op,
-    const paddle::framework::ValueExecutionInfo& value_exec_info,
-    const paddle::dialect::OpYamlInfoParser& op_yaml_info,
-    paddle::framework::RuntimeContext* runtime_ctx) {
-  const paddle::framework::Scope* inner_scope = value_exec_info.GetScope();
+void BuildRuntimeContext(pir::Operation* op,
+                         const ValueExecutionInfo& value_exec_info,
+                         const paddle::dialect::OpYamlInfoParser& op_yaml_info,
+                         RuntimeContext* runtime_ctx) {
+  const Scope* inner_scope = value_exec_info.GetScope();
   VLOG(6) << "BuildPhiContext in scope[" << inner_scope << "]";
 
   auto& vec_kernel_fn_tensor_params = op_yaml_info.TensorParams(true);
@@ -754,11 +735,11 @@ void BuildRuntimeContext(
                type.isa<paddle::dialect::AllocatedSelectedRowsType>()) {
       runtime_ctx->outputs[legacy_arg_name] = {var};
     } else if (type.isa<pir::VectorType>()) {
-      auto var_ref = var->Get<paddle::framework::VariableRefArray>();
-      std::vector<paddle::framework::Variable*> vec_tmp;
+      auto var_ref = var->Get<VariableRefArray>();
+      std::vector<Variable*> vec_tmp;
       vec_tmp.reserve(var_ref.size());
       for (size_t k = 0; k < var_ref.size(); ++k) {
-        vec_tmp.push_back(const_cast<paddle::framework::Variable*>(var_ref[k]));
+        vec_tmp.push_back(const_cast<Variable*>(var_ref[k]));
       }
       runtime_ctx->outputs[legacy_arg_name] = vec_tmp;
     } else {
@@ -769,9 +750,9 @@ void BuildRuntimeContext(
   }
 }
 
-std::shared_ptr<paddle::framework::OperatorBase> BuildOperatorBase(
+std::shared_ptr<OperatorBase> BuildOperatorBase(
     pir::Operation* op,
-    const paddle::framework::ValueExecutionInfo& value_exec_info,
+    const ValueExecutionInfo& value_exec_info,
     const paddle::dialect::OpYamlInfoParser& op_yaml_info) {
   paddle::framework::VariableNameMap in_name_map;
   paddle::framework::VariableNameMap out_name_map;
@@ -899,7 +880,7 @@ std::shared_ptr<OperatorBase> BuildOperatorBase(
                 << value_exec_info.GetVarName(ptr);
       } else if (ptr.type().isa<pir::VectorType>()) {
         auto var = scope->FindVar(value_exec_info.GetVarName(ptr));
-        auto var_ref = var->Get<paddle::framework::VariableRefArray>();
+        auto var_ref = var->Get<VariableRefArray>();
         for (size_t k = 0; k < var_ref.size(); ++k) {
           out_name_map[legacy_arg_name].push_back(
               value_exec_info.GetVarName(var_ref[k]));
@@ -913,12 +894,13 @@ std::shared_ptr<OperatorBase> BuildOperatorBase(
     }
   }
 
-  auto& op_info = paddle::framework::OpInfoMap::Instance().Get(fluid_op_name);
+  auto& op_info = OpInfoMap::Instance().Get(fluid_op_name);
   auto ptr =
       op_info.Creator()(fluid_op_name, in_name_map, out_name_map, attr_map);
 
-  std::shared_ptr<paddle::framework::OperatorBase> res(ptr);
+  std::shared_ptr<OperatorBase> res(ptr);
   return res;
 }
-}  // namespace pir
+
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h
index 4ced634d07051..2617dcd45e5f4 100644
--- a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h
+++ b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h
@@ -140,11 +140,6 @@ class ValueExecutionInfo {
   std::vector<Variable*> var_list_;
 };
 
-}  // namespace framework
-}  // namespace paddle
-
-namespace pir {
-
 // NOTE(zhangbo): Some operators of Paddle support optional inputs or outputs,
 // representing whether the input or output exists. In the Pir, whether the
 // value itself is empty or the type it holds is empty is used to indicate
@@ -156,21 +151,19 @@ inline bool IsInvalid(pir::Value value) {
   return true;
 }
 
-void BuildScope(
-    const pir::Block& block,
-    const std::string& var_name_prefix,
-    std::map<pir::Block*, paddle::framework::Scope*>* sub_blocks,
-    paddle::framework::ValueExecutionInfo* value_exe_info = nullptr);
+void BuildScope(const pir::Block& block,
+                const std::string& var_name_prefix,
+                std::map<pir::Block*, Scope*>* sub_blocks,
+                ValueExecutionInfo* value_exe_info = nullptr);
 
-void BuildRuntimeContext(
-    pir::Operation* op,
-    const paddle::framework::ValueExecutionInfo& value_exec_info,
-    const paddle::dialect::OpYamlInfoParser& op_yaml_info,
-    paddle::framework::RuntimeContext* runtime_ctx);
+void BuildRuntimeContext(pir::Operation* op,
+                         const ValueExecutionInfo& value_exec_info,
+                         const paddle::dialect::OpYamlInfoParser& op_yaml_info,
+                         RuntimeContext* runtime_ctx);
 
-std::shared_ptr<paddle::framework::OperatorBase> BuildOperatorBase(
+std::shared_ptr<OperatorBase> BuildOperatorBase(
     pir::Operation* op,
-    const paddle::framework::ValueExecutionInfo& value_exec_info,
+    const ValueExecutionInfo& value_exec_info,
     const paddle::dialect::OpYamlInfoParser& op_yaml_info);
 
 template <typename Context,
           typename InType,
           typename OutType,
           typename InListType,
           typename OutListType,
           bool is_kernel>
-void BuildPhiContext(
-    pir::Operation* op,
-    const paddle::framework::ValueExecutionInfo& value_exec_info,
-    const paddle::dialect::OpYamlInfoParser& op_yaml_info,
-    Context* ctx) {
-  paddle::framework::Scope* inner_scope = value_exec_info.GetScope();
-  VLOG(6) << "Build " << get_type_name<Context>() << "] inner_scope["
+void BuildPhiContext(pir::Operation* op,
+                     const ValueExecutionInfo& value_exec_info,
+                     const paddle::dialect::OpYamlInfoParser& op_yaml_info,
+                     Context* ctx) {
+  Scope* inner_scope = value_exec_info.GetScope();
+  VLOG(6) << "Build [" << pir::get_type_name<Context>() << "] inner_scope["
           << inner_scope << "]";
 
   auto attr_map = op->attributes();
@@ -225,9 +217,9 @@ void BuildPhiContext(
     if (var->IsType<phi::DenseTensor>()) {
       const phi::TensorBase* tensor_in = &(var->Get<phi::DenseTensor>());
       ctx->EmplaceBackInput(InType(tensor_in));
-    } else if (var->IsType<paddle::framework::VariableRefArray>()) {
+    } else if (var->IsType<VariableRefArray>()) {
       InListType inputs;
-      auto& variable_array = var->Get<paddle::framework::VariableRefArray>();
+      auto& variable_array = var->Get<VariableRefArray>();
       for (size_t i = 0; i < variable_array.size(); ++i) {
         if (variable_array[i]->IsType<phi::DenseTensor>()) {
           inputs.emplace_back(InType(const_cast<phi::DenseTensor*>(
@@ -266,8 +258,8 @@ void BuildPhiContext(
           &(inner_scope->FindVar(in_var_name)->Get<phi::DenseTensor>()));
       ctx->EmplaceBackAttr(attr);
     } else if (ptr.type().isa<pir::VectorType>()) {
-      auto& tensor_array = inner_scope->FindVar(in_var_name)
-                               ->Get<paddle::framework::VariableRefArray>();
+      auto& tensor_array =
+          inner_scope->FindVar(in_var_name)->Get<VariableRefArray>();
       if (tensor_array.size() == 1) {
         phi::Attribute attr =
             phi::TensorRef(&(tensor_array[0]->Get<phi::DenseTensor>()));
@@ -441,7 +433,7 @@ void BuildPhiContext(
       OutListType outputs;
       auto& variable_array =
          inner_scope->FindVar(value_exec_info.GetVarName(out_ptr))
-              ->Get<paddle::framework::VariableRefArray>();
+              ->Get<VariableRefArray>();
       for (size_t i = 0; i < variable_array.size(); ++i) {
         if (variable_array[i]->IsType<phi::DenseTensor>()) {
           outputs.emplace_back(OutType(const_cast<phi::DenseTensor*>(
@@ -466,4 +458,5 @@ void BuildPhiContext(
   VLOG(6) << "Done build phi context";
 }
 
-}  // namespace pir
+}  // namespace framework
+}  // namespace paddle

From 9a4436941bf67b5516a35e2843c9f0d73a8fd860 Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Wed, 27 Sep 2023 11:34:19 +0000
Subject: [PATCH 06/12] delete sub_blocks

---
 .../instruction/cond_instruction.cc           | 18 +++++------
 .../instruction/cond_instruction.h            | 12 +++-----
 .../new_executor/new_ir_interpreter.cc        | 14 +++------
 .../new_executor/new_ir_interpreter.h         |  2 --
 .../pir_adaptor/pir_adaptor_util.cc           | 25 ++++------------
 .../pir_adaptor/pir_adaptor_util.h            | 30 ++++++++-----------
 6 files changed, 33 insertions(+), 68 deletions(-)

diff --git a/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc b/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc
index 76bd0e2c13782..ad6a55615012c 100644
--- a/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc
+++ b/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc
@@ -100,14 +100,10 @@ void GetOutsideOpInputs(
   }
 }
 
-CondInstruction::CondInstruction(
-    size_t id,
-    const platform::Place& place,
-    pir::Operation* op,
-    Scope* scope,
-    Scope* local_scope,
-    ValueExecutionInfo* value_exec_info,
-    const std::map<pir::Block*, Scope*>& sub_blocks)
+CondInstruction::CondInstruction(size_t id,
+                                 const platform::Place& place,
+                                 pir::Operation* op,
+                                 ValueExecutionInfo* value_exec_info)
     : InstructionBase(id, place) {
   op_ = op;
   VLOG(6) << "finish process dist attributes";
@@ -115,7 +111,7 @@ CondInstruction::CondInstruction(
   SetKernelType(AnalyseOpFuncType(op, place));
   VLOG(6) << "finish process analyse kernel type";
 
-  Scope* inner_scope = local_scope == nullptr ? scope : local_scope;
+  Scope* inner_scope = value_exec_info->GetScope();
 
   VLOG(6) << "finish process inputs outputs index";
 
@@ -140,7 +136,7 @@ CondInstruction::CondInstruction(
   auto true_branch_yied_inputs = GetYiedOpInputs(true_branch_block);
   auto false_branch_yied_inputs = GetYiedOpInputs(false_branch_block);
 
-  auto true_scope = sub_blocks.at(true_branch_block);
+  Scope* true_scope = &(value_exec_info->GetScope()->NewScope());
   true_branch_inter =
       new NewIRInterpreter(place,
                            {},
@@ -156,7 +152,7 @@ CondInstruction::CondInstruction(
   }
   true_branch_inter->SetSkipGcVars(true_skip_gc_names_set);
 
-  auto false_scope = sub_blocks.at(false_branch_block);
+  Scope* false_scope = &(value_exec_info->GetScope()->NewScope());
   false_branch_inter =
       new NewIRInterpreter(place,
                            {},
diff --git a/paddle/fluid/framework/new_executor/instruction/cond_instruction.h b/paddle/fluid/framework/new_executor/instruction/cond_instruction.h
index 5c3879c49905f..43cfc78f0311d 100644
--- a/paddle/fluid/framework/new_executor/instruction/cond_instruction.h
+++ b/paddle/fluid/framework/new_executor/instruction/cond_instruction.h
@@ -29,14 +29,10 @@ class ValueExecutionInfo;
 
 class CondInstruction : public InstructionBase {
  public:
-  CondInstruction(
-      size_t id,
-      const platform::Place& place,
-      ::pir::Operation* op,
-      Scope* scope,
-      Scope* local_scope,
-      ValueExecutionInfo* value_exe_info,
-      const std::map<pir::Block*, Scope*>& sub_blocks);
+  CondInstruction(size_t id,
+                  const platform::Place& place,
+                  ::pir::Operation* op,
+                  ValueExecutionInfo* value_exe_info);
 
   void Run() override;
 
diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc
index 4d133135d4c17..230376b12c609 100644
--- a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc
+++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc
@@ -114,7 +114,7 @@ NewIRInterpreter::NewIRInterpreter(
 
   std::stringstream ss;
   ss << this;
-  BuildScope(*ir_block_, ss.str(), &sub_blocks_, value_exe_info_.get());
+  BuildScope(*ir_block_, ss.str(), value_exe_info_.get());
 }
 
 NewIRInterpreter::NewIRInterpreter(
@@ -176,7 +176,7 @@ NewIRInterpreter::NewIRInterpreter(
 
   std::stringstream ss;
   ss << this;
-  BuildScope(*ir_block_, ss.str(), &sub_blocks_, value_exe_info_.get());
+  BuildScope(*ir_block_, ss.str(), value_exe_info_.get());
 }
 
 NewIRInterpreter::~NewIRInterpreter() {
@@ -564,14 +564,8 @@ void NewIRInterpreter::BuildInstruction() {
     } else if (op->dialect()->name() == "cf") {
       continue;
     } else if (op->dialect()->name() == "pd_op") {
-      vec_instruction_base_.emplace_back(
-          std::make_unique<CondInstruction>(op_idx++,
-                                            place_,
-                                            op,
-                                            scope_,
-                                            local_scope_,
-                                            value_exe_info_.get(),
-                                            sub_blocks_));
+      vec_instruction_base_.emplace_back(std::make_unique<CondInstruction>(
          op_idx++, place_, op, value_exe_info_.get()));
     } else if (op->dialect()->name() == "pd_kernel") {
       auto op_name = op->attributes()
                          .at("op_name")
diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.h b/paddle/fluid/framework/new_executor/new_ir_interpreter.h
index 04a149bb6d692..94975b366c17d 100644
--- a/paddle/fluid/framework/new_executor/new_ir_interpreter.h
+++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.h
@@ -216,8 +216,6 @@ class NewIRInterpreter : public InterpreterBaseImpl {
   // value execution info
   std::shared_ptr<ValueExecutionInfo> value_exe_info_;
 
-  std::map<pir::Block*, Scope*> sub_blocks_;
-
   std::vector<int> var_ref_count_;
 
   interpreter::NewIrDependencyBuilder ir_dependency_builder_;
 
diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc
index d0abdf02e2a83..53db667f3501d 100644
--- a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc
+++ b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc
@@ -14,7 +14,6 @@
 
 #include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h"
 
-#include "glog/logging.h"
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/scope.h"
@@ -298,7 +297,6 @@ Variable* CreateVar(pir::Value value,
   }
 
   Variable* var = nullptr;
-
   std::string name = var_name_prefix + "_inner_var_" +
                      std::to_string(value_exe_info->GetVar2VarName().size());
 
@@ -329,7 +327,7 @@ void CheckInputVars(pir::Operation* op,
           execution_info->HasValue(value),
           true,
           phi::errors::PreconditionNotMet(
-              "input should in name map, [%d] 'th input of [%s] op",
+              "input should be in execution_info, [%d]'th input of [%s] op",
               i,
               op_name));
     }
@@ -374,14 +372,14 @@ void BuildValue(pir::Value value,
       tensor_array->emplace_back(var_i);
     }
   } else {
-    PADDLE_THROW(phi::errors::PreconditionNotMet(
-        "Output only support DenseTensorType or VectorType"));
+    PADDLE_THROW(
+        phi::errors::PreconditionNotMet("Output only supports DenseTensorType "
+                                        "or SelectedRowsType or VectorType"));
   }
 }
 
 void HandleForSpecialOp(pir::Operation* op,
                         const std::string& var_name_prefix,
-                        std::map<pir::Block*, Scope*>* sub_blocks,
                         ValueExecutionInfo* value_exe_info) {
   std::string op_name = op->name();
   if (op->attributes().count("op_name")) {
@@ -555,20 +553,8 @@ void HandleForSpecialOp(pir::Operation* op,
 
   if (op_name == "pd_op.if") {
     auto if_op = op->dyn_cast<paddle::dialect::IfOp>();
-
-    auto true_block = if_op.true_block();
-
-    auto false_block = if_op.false_block();
-
-    auto& true_branch_scope = value_exe_info->GetScope()->NewScope();
-    sub_blocks->emplace(true_block, &true_branch_scope);
-
-    auto& false_branch_scope = value_exe_info->GetScope()->NewScope();
-    sub_blocks->emplace(false_block, &false_branch_scope);
-
     for (size_t i = 0; i < if_op->num_results(); ++i) {
       // auto true_value = true_yeid_op->operand_source(i);
-
       auto if_op_out_value = if_op->result(i);
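      // Each IfOp result is materialized as a variable in the parent scope
      // (via BuildValue below), so whichever branch runs can copy the tensors
      // it yields into these shared output variables.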
      BuildValue(if_op_out_value, var_name_prefix, value_exe_info);
     }
 
@@ -628,7 +614,6 @@ void HandleForInplaceOp(pir::Operation* op,
 // is created in inner_scope.
 void BuildScope(const pir::Block& block,
                 const std::string& var_name_prefix,
-                std::map<pir::Block*, Scope*>* sub_blocks,
                 ValueExecutionInfo* value_exe_info) {
   VLOG(4) << "***** [before build] scope"
           << "(" << value_exe_info->GetScope() << ") ******\n"
@@ -645,7 +630,7 @@ void BuildScope(const pir::Block& block,
     }
     VLOG(4) << "build op:" << op_name;
     if (SpecialOps.count(op_name)) {
-      HandleForSpecialOp(op, var_name_prefix, sub_blocks, value_exe_info);
+      HandleForSpecialOp(op, var_name_prefix, value_exe_info);
       continue;
     }
 
diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h
index 2617dcd45e5f4..87603f2f14b15 100644
--- a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h
+++ b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h
@@ -14,31 +14,29 @@
 
 #pragma once
 
-#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
-#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
-#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
-#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h"
-#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
-#include "paddle/phi/core/meta_tensor.h"
-#include "paddle/pir/core/builtin_attribute.h"
-#include "paddle/pir/core/ir_context.h"
-#include "paddle/pir/core/program.h"
-#include "paddle/pir/core/utils.h"
-
 #include "paddle/fluid/framework/new_executor/interpreter/execution_config.h"
+#include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/variable.h"
 #include "paddle/fluid/framework/variable_helper.h"
-#include "paddle/phi/core/kernel_context.h"
-
-#include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h"
 #include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h"
+#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
 #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
+#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
+#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
 #include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h"
+#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h"
+#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
 #include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/core/kernel_context.h"
+#include "paddle/phi/core/meta_tensor.h"
+#include "paddle/pir/core/builtin_attribute.h"
+#include "paddle/pir/core/ir_context.h"
+#include "paddle/pir/core/program.h"
 #include "paddle/pir/core/type_name.h"
+#include "paddle/pir/core/utils.h"
 
 #include "glog/logging.h"
 
@@ -130,8 +128,7 @@ class ValueExecutionInfo {
 
   std::unordered_map<::pir::Value, std::string> value_2_var_name_;
 
-  std::unordered_map<const paddle::framework::Variable*, std::string>
-      var_2_var_name_;
+  std::unordered_map<const Variable*, std::string> var_2_var_name_;
 
   std::map<std::string, int> var_name_2_id_;
 
@@ -153,7 +150,6 @@ inline bool IsInvalid(pir::Value value) {
 
 void BuildScope(const pir::Block& block,
                 const std::string& var_name_prefix,
-                std::map<pir::Block*, Scope*>* sub_blocks,
                 ValueExecutionInfo* value_exe_info = nullptr);
 
 void BuildRuntimeContext(pir::Operation* op,

From ba8caaa0805c7c4f413fb8275a524a9c68d3c581 Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Wed, 27 Sep 2023 11:48:16 +0000
Subject: [PATCH 07/12] refine

---
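Editor's note on the preceding refactor: rather than BuildScope pre-creating one child scope per sub-block and threading a std::map<pir::Block*, Scope*> through every call, CondInstruction now creates the child scopes it needs directly. A condensed sketch of the new ownership pattern, with the interpreter's remaining constructor arguments elided:

    // Each branch interpreter owns a fresh child of the cond's own scope, so
    // scope lifetime is tied to the instruction that uses it.
    Scope* true_scope = &(value_exec_info->GetScope()->NewScope());
    true_branch_inter =
        new NewIRInterpreter(place, {}, true_branch_block, true_scope,
                             /* remaining args elided */);

This removes the interpreter-wide sub_blocks_ member and the extra Scope* and map constructor parameters.
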
 .../framework/new_executor/instruction/cond_instruction.h |  6 ++++--
 paddle/fluid/framework/new_executor/new_ir_interpreter.cc |  8 +++++---
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/paddle/fluid/framework/new_executor/instruction/cond_instruction.h b/paddle/fluid/framework/new_executor/instruction/cond_instruction.h
index 43cfc78f0311d..4f62e76caa67a 100644
--- a/paddle/fluid/framework/new_executor/instruction/cond_instruction.h
+++ b/paddle/fluid/framework/new_executor/instruction/cond_instruction.h
@@ -44,6 +44,8 @@ class CondInstruction : public InstructionBase {
   void CopyBranchOutput(const std::vector<std::string>& var_names,
                         const NewIRInterpreter* inter);
 
+  ::pir::Operation* op_;
+
   std::string cond_name_{"cond_instruction"};
 
   Variable* cond_var;
@@ -51,12 +53,12 @@ class CondInstruction : public InstructionBase {
   std::vector<Variable*> if_op_outputs_;
 
   NewIRInterpreter* true_branch_inter;
+
   NewIRInterpreter* false_branch_inter;
 
   std::vector<std::string> true_skip_gc_names_;
-  std::vector<std::string> false_skip_gc_names_;
 
-  ::pir::Operation* op_;
+  std::vector<std::string> false_skip_gc_names_;
 };
 
 }  // namespace framework
diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc
index 230376b12c609..308b1265fc63c 100644
--- a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc
+++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc
@@ -47,6 +47,7 @@
 #include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h"
 #include "paddle/fluid/pir/dialect/kernel/ir/kernel_op.h"
 #include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h"
+#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
 #include "paddle/fluid/pir/dialect/operator/utils/utils.h"
 
 #include "paddle/pir/core/builtin_attribute.h"
@@ -558,12 +559,13 @@ void NewIRInterpreter::BuildInstruction() {
     VLOG(6) << "Build Instruction for op: " << op_idx;
     if (op->dialect()->name() == "builtin") {
       if (interpreter::GetSpecialOpNames().count(op->name())) {
-        VLOG(6) << "skip process " << op->name();
+        VLOG(6) << "skip process builtin dialect op: " << op->name();
         continue;
       }
     } else if (op->dialect()->name() == "cf") {
+      VLOG(6) << "skip process cf dialect op: " << op->name();
       continue;
-    } else if (op->dialect()->name() == "pd_op") {
+    } else if (op->isa<paddle::dialect::IfOp>()) {
       vec_instruction_base_.emplace_back(std::make_unique<CondInstruction>(
           op_idx++, place_, op, value_exe_info_.get()));
     } else if (op->dialect()->name() == "pd_kernel") {
       auto op_name = op->attributes()
                          .at("op_name")
@@ -577,7 +579,7 @@ void NewIRInterpreter::BuildInstruction() {
     }
     VLOG(6) << "process " << op_name;
 
-    if (op->name().compare(paddle::dialect::LegacyKernelOp::name()) == 0) {
+    if (op->isa<paddle::dialect::LegacyKernelOp>()) {
       vec_instruction_base_.emplace_back(
           std::make_unique<LegacyKernelInstruction>(
               op_idx++, place_, op, *(value_exe_info_.get())));

From 374710f9ffbefbc598d54869cc9d10636f7693d0 Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Wed, 27 Sep 2023 12:46:41 +0000
Subject: [PATCH 08/12] refine

---
 .../instruction/cond_instruction.cc           | 74 +++++++++----------
 .../instruction/cond_instruction.h            | 10 ++-
 .../instruction/instruction_util.cc           | 60 +++++++--------
 .../new_executor/new_ir_interpreter.cc        |  2 +-
 .../new_executor/new_ir_interpreter.h         |  4 +-
 5 files changed, 73 insertions(+), 77 deletions(-)

diff --git a/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc b/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc
index ad6a55615012c..67e1bc409f9f4 100644
--- a/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc
+++ b/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc
@@ -47,7 +47,6 @@ std::vector<pir::Value> GetYiedOpInputs(pir::Block* block) {
       }
     }
   }
-
   return vec_res;
 }
 
@@ -56,7 +55,7 @@ void GetInputIds(pir::Operation* op,
                  std::unordered_map<pir::Value, std::vector<int>>* input_ids) {
   for (size_t i = 0; i < op->num_operands(); i++) {
     pir::Value value = op->operand_source(i);
-    if (value) {
+    if (value && value.type()) {
       PADDLE_ENFORCE_EQ(
           value_exec_info.HasValue(value),
           true,
@@ -105,39 +104,28 @@ CondInstruction::CondInstruction(size_t id,
                                  pir::Operation* op,
                                  ValueExecutionInfo* value_exec_info)
     : InstructionBase(id, place) {
-  op_ = op;
-  VLOG(6) << "finish process dist attributes";
-
-  SetKernelType(AnalyseOpFuncType(op, place));
-  VLOG(6) << "finish process analyse kernel type";
-
-  Scope* inner_scope = value_exec_info->GetScope();
-
-  VLOG(6) << "finish process inputs outputs index";
-
   PADDLE_ENFORCE(
       op->isa<paddle::dialect::IfOp>(),
       phi::errors::PreconditionNotMet("Cond instruction only supports if op"));
-
   auto if_op = op->dyn_cast<paddle::dialect::IfOp>();
+  op_ = op;
+
+  SetKernelType(AnalyseOpFuncType(op, place));
+  VLOG(6) << "finish process analyse kernel type";
 
+  auto cond_value = if_op.operand_source(0);
+  cond_var_ = value_exec_info->GetScope()->FindVar(
+      value_exec_info->GetValue2VarName().at(cond_value));
   for (size_t i = 0; i < if_op.num_results(); ++i) {
-    if_op_outputs_.push_back(inner_scope->GetVar(
+    output_vars_.push_back(value_exec_info->GetScope()->GetVar(
         value_exec_info->GetValue2VarName().at(if_op.result(i))));
   }
-
-  auto cond_value = if_op.operand_source(0);
-  auto var_name = value_exec_info->GetValue2VarName().at(cond_value);
-  cond_var = inner_scope->FindVar(var_name);
+  VLOG(6) << "finish process cond_var and output_vars";
 
   auto true_branch_block = if_op.true_block();
-  auto false_branch_block = if_op.false_block();
-
-  auto true_branch_yied_inputs = GetYiedOpInputs(true_branch_block);
-  auto false_branch_yied_inputs = GetYiedOpInputs(false_branch_block);
-
-  Scope* true_scope = &(value_exec_info->GetScope()->NewScope());
-  true_branch_inter =
+  auto true_branch_yied_inputs = GetYiedOpInputs(true_branch_block);
+  Scope* true_scope = &(value_exec_info->GetScope()->NewScope());
+  true_branch_inter_ =
       new NewIRInterpreter(place,
                            {},
                            true_branch_block,
@@ -147,13 +135,16 @@ CondInstruction::CondInstruction(size_t id,
 
   std::set<std::string> true_skip_gc_names_set;
   for (auto value : true_branch_yied_inputs) {
-    true_skip_gc_names_.push_back(true_branch_inter->GetNameByValue(value));
-    true_skip_gc_names_set.insert(true_branch_inter->GetNameByValue(value));
+    true_skip_gc_names_.push_back(true_branch_inter_->GetNameByValue(value));
+    true_skip_gc_names_set.insert(true_branch_inter_->GetNameByValue(value));
   }
-  true_branch_inter->SetSkipGcVars(true_skip_gc_names_set);
+  true_branch_inter_->SetSkipGcVars(true_skip_gc_names_set);
+  VLOG(6) << "finish process true branch interpreter";
 
+  auto false_branch_block = if_op.false_block();
+  auto false_branch_yied_inputs = GetYiedOpInputs(false_branch_block);
   Scope* false_scope = &(value_exec_info->GetScope()->NewScope());
-  false_branch_inter =
+  false_branch_inter_ =
       new NewIRInterpreter(place,
                            {},
                            false_branch_block,
@@ -163,19 +154,19 @@ CondInstruction::CondInstruction(size_t id,
 
   std::set<std::string> false_skip_gc_names_set;
   for (auto value : false_branch_yied_inputs) {
-    false_skip_gc_names_.push_back(false_branch_inter->GetNameByValue(value));
-    false_skip_gc_names_set.insert(false_branch_inter->GetNameByValue(value));
+    false_skip_gc_names_.push_back(false_branch_inter_->GetNameByValue(value));
+    false_skip_gc_names_set.insert(false_branch_inter_->GetNameByValue(value));
   }
-  false_branch_inter->SetSkipGcVars(false_skip_gc_names_set);
-
-  // the true branch and false branch input will be the if_op inputs
+  false_branch_inter_->SetSkipGcVars(false_skip_gc_names_set);
+  VLOG(6) << "finish process false branch interpreter";
 
+  // NOTE(zhangbo): IfOp sub_block's inputs include two kinds of values: one is
+  // the OpOperands of the IfOp, and the other is external Values used inside
+  // true_block or false_block.
   std::unordered_map<pir::Value, std::vector<int>> inputs;
-
   GetInputIds(op, *value_exec_info, &inputs);
   GetOutsideOpInputs(true_branch_block, *value_exec_info, &inputs);
   GetOutsideOpInputs(false_branch_block, *value_exec_info, &inputs);
-
   SetInputs(inputs);
 
   std::unordered_map<pir::Value, std::vector<int>> outputs;
@@ -194,26 +185,27 @@ CondInstruction::CondInstruction(size_t id,
     }
   }
   SetOutputs(outputs);
+  VLOG(6) << "finish process inputs outputs index";
 }
 
 void CondInstruction::CopyBranchOutput(
     const std::vector<std::string>& var_names, const NewIRInterpreter* inter) {
   for (size_t i = 0; i < var_names.size(); ++i) {
-    auto* inner_var = inter->local_scope()->GetVar(var_names[i]);
+    auto* inner_var = inter->InnerScope()->GetVar(var_names[i]);
 
-    if_op_outputs_[i]->GetMutable<phi::DenseTensor>()->ShareDataWith(
+    output_vars_[i]->GetMutable<phi::DenseTensor>()->ShareDataWith(
         inner_var->Get<phi::DenseTensor>());
   }
 }
 
 void CondInstruction::Run() {
   DeviceContext().Wait();
-  if (cond_var->Get<phi::DenseTensor>().data<bool>()[0]) {
-    true_branch_inter->Run({}, false);
-    CopyBranchOutput(true_skip_gc_names_, true_branch_inter);
+  if (cond_var_->Get<phi::DenseTensor>().data<bool>()[0]) {
+    true_branch_inter_->Run({}, false);
+    CopyBranchOutput(true_skip_gc_names_, true_branch_inter_);
   } else {
-    false_branch_inter->Run({}, false);
-    CopyBranchOutput(false_skip_gc_names_, false_branch_inter);
+    false_branch_inter_->Run({}, false);
+    CopyBranchOutput(false_skip_gc_names_, false_branch_inter_);
   }
 
   // copy output
diff --git a/paddle/fluid/framework/new_executor/instruction/cond_instruction.h b/paddle/fluid/framework/new_executor/instruction/cond_instruction.h
index 4f62e76caa67a..974784dbe982a 100644
--- a/paddle/fluid/framework/new_executor/instruction/cond_instruction.h
+++ b/paddle/fluid/framework/new_executor/instruction/cond_instruction.h
@@ -48,14 +48,16 @@ class CondInstruction : public InstructionBase {
 
   std::string cond_name_{"cond_instruction"};
 
-  Variable* cond_var;
+  Variable* cond_var_;
 
-  std::vector<Variable*> if_op_outputs_;
+  std::vector<Variable*> output_vars_;
 
-  NewIRInterpreter* true_branch_inter;
+  NewIRInterpreter* true_branch_inter_;
 
-  NewIRInterpreter* false_branch_inter;
+  NewIRInterpreter* false_branch_inter_;
 
+  // TODO(zhangbo): Currently, only the output of IfOp is included. In the
+  // future, need to consider how to support IfGradOp using IfOp value.
   std::vector<std::string> true_skip_gc_names_;
 
-  std::vector<std::string> false_skip_gc_names_;
-
-  ::pir::Operation* op_;
+  std::vector<std::string> false_skip_gc_names_;
 };
 
 }  // namespace framework
diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_util.cc b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc
index f888c0d514d51..a9791fefdbc38 100644
--- a/paddle/fluid/framework/new_executor/instruction/instruction_util.cc
+++ b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc
@@ -142,46 +142,48 @@ OpFuncType AnalyseOpFuncType(pir::Operation* op, const platform::Place& place) {
     return OpFuncType::kCpuSync;
   }
 
-  if (op->dialect()->name() == "pd_op") {
-    return OpFuncType::kGpuAsync;
-  }
-
-  auto kernel_key = op->attributes()
-                        .at("kernel_key")
-                        .dyn_cast<paddle::dialect::KernelAttribute>()
-                        .data();
-  if (phi::TransToPhiPlace(kernel_key.backend()).GetType() ==
-      phi::AllocationType::CPU) {
-    return OpFuncType::kCpuSync;
-  }
-
   PADDLE_ENFORCE_EQ(interpreter::IsSupportedHeterPlace(place),
                     true,
                     phi::errors::Fatal("Unsupported current place %s", place));
 
+  auto& op_attributes = op->attributes();
+
+  if ((op->dialect()->name() == "pd_kernel") &&
+      (op_attributes.count("kernel_key") > 0)) {
+    auto kernel_key = op_attributes.at("kernel_key")
+                          .dyn_cast<paddle::dialect::KernelAttribute>()
+                          .data();
+    if (phi::TransToPhiPlace(kernel_key.backend()).GetType() ==
+        phi::AllocationType::CPU) {
+      return OpFuncType::kCpuSync;
+    }
+  }
+
   // Some GPU OPs do not launch CUDA Kernel, but spend a lot of time on CPU
   // computing. They execute serially in device thread and block CUDA kernel
   // launching in other GPU OPs. To improve performance, set them as kGpuSync
   // so that they are dispatched to the host thread.
-  auto& op_attributes = op->attributes();
-  auto op_name =
-      op_attributes.at("op_name").dyn_cast<pir::StrAttribute>().AsString();
-  if (op_name == "pd_op.coalesce_tensor" &&
-      (!platform::is_xpu_place(place) ||
-       op->attribute<pir::BoolAttribute>("persist_output").data() == false) &&
-      op->attribute<pir::BoolAttribute>("set_constant").data() == false &&
-      op->attribute<pir::BoolAttribute>("copy_data").data() == false) {
-    return OpFuncType::kGpuSync;
-  }
+  if ((op->dialect()->name() == "pd_kernel") &&
+      (op_attributes.count("op_name") > 0)) {
+    auto op_name =
+        op_attributes.at("op_name").dyn_cast<pir::StrAttribute>().AsString();
+    if (op_name == "pd_op.coalesce_tensor" &&
+        (!platform::is_xpu_place(place) ||
+         op->attribute<pir::BoolAttribute>("persist_output").data() == false) &&
+        op->attribute<pir::BoolAttribute>("set_constant").data() == false &&
+        op->attribute<pir::BoolAttribute>("copy_data").data() == false) {
+      return OpFuncType::kGpuSync;
+    }
 
-  // for memcpy explicitly called by user
-  if (platform::is_gpu_place(place) && op_name == "pd_op.memcpy_d2h") {
-    return OpFuncType::kGpuSync;
-  }
+    if (platform::is_gpu_place(place) && op_name == "pd_op.memcpy_d2h") {
+      return OpFuncType::kGpuSync;
+    }
 
-  if (op_name == "pd_op.shape") {
-    return OpFuncType::kGpuSync;
+    if (op_name == "pd_op.shape") {
+      return OpFuncType::kGpuSync;
+    }
   }
+
   return OpFuncType::kGpuAsync;
 }
diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc
index 308b1265fc63c..e8508094f4448 100644
--- a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc
+++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc
@@ -380,7 +380,7 @@ std::string NewIRInterpreter::GetDepsString() const {
 
 bool NewIRInterpreter::HasLocalScope() const { return local_scope_ != nullptr; }
 
-Scope* NewIRInterpreter::InnerScope() {
+Scope* NewIRInterpreter::InnerScope() const {
   return local_scope_ != nullptr ?
local_scope_ : scope_; } diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.h b/paddle/fluid/framework/new_executor/new_ir_interpreter.h index 94975b366c17d..3a128791cdfce 100644 --- a/paddle/fluid/framework/new_executor/new_ir_interpreter.h +++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.h @@ -78,6 +78,8 @@ class NewIRInterpreter : public InterpreterBaseImpl { const Scope* local_scope() const override; + Scope* InnerScope() const; + const platform::Place& GetPlace() const override { return place_; } void SetOutputHooks(const std::vector& hookfuncs) override { @@ -115,8 +117,6 @@ class NewIRInterpreter : public InterpreterBaseImpl { // scope bool HasLocalScope() const; - Scope* InnerScope(); - // For log and debug std::string GetDepsString() const; From 4af0d5a525ad3a79035721d1b29a7830d5b80caa Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Thu, 28 Sep 2023 06:50:53 +0000 Subject: [PATCH 09/12] add ut --- test/legacy_test/test_cond.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/legacy_test/test_cond.py b/test/legacy_test/test_cond.py index 55e6f8116cf33..b5ee90871f67f 100644 --- a/test/legacy_test/test_cond.py +++ b/test/legacy_test/test_cond.py @@ -12,11 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys import unittest import numpy as np from simple_nets import batchnorm_fc_with_inputs, simple_fc_net_with_inputs +sys.path.append("../dygraph_to_static") +from dygraph_to_static_util import test_and_compare_with_new_ir + import paddle from paddle import base from paddle.base import core, framework @@ -27,6 +31,7 @@ class TestCondInputOutput(unittest.TestCase): + @test_and_compare_with_new_ir() def test_return_single_var(self): """ pseudocode: @@ -73,6 +78,7 @@ def false_func(): np.asarray(ret), np.full((3, 2), -1, np.int32), rtol=1e-05 ) + @test_and_compare_with_new_ir() def test_return_0d_tensor(self): """ pseudocode: @@ -110,6 +116,7 @@ def false_func(): np.testing.assert_allclose(np.asarray(ret), np.array(2), rtol=1e-05) self.assertEqual(ret.shape, ()) + @test_and_compare_with_new_ir() def test_0d_tensor_as_cond(self): """ pseudocode: @@ -210,6 +217,7 @@ def test_0d_tensor_dygraph(self): ) self.assertEqual(a.grad.shape, []) + @test_and_compare_with_new_ir() def test_return_var_tuple(self): """ pseudocode: @@ -347,6 +355,7 @@ def false_func(): self.assertIsNone(out2) self.assertIsNone(out3) + @test_and_compare_with_new_ir() def test_wrong_structure_exception(self): """ test returning different number of tensors cannot merge into output From df3f78194d0c24a4cbef0204a71f9c2b10040087 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Mon, 9 Oct 2023 02:04:10 +0000 Subject: [PATCH 10/12] fix --- paddle/fluid/framework/new_executor/new_ir_interpreter.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc index 21817c2f443d5..b6b2be142a0dd 100644 --- a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc @@ -47,6 +47,7 @@ #include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_op.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" #include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include 
"paddle/pir/core/builtin_attribute.h" From 2c659b9f8da408f1db2821e652d2c4260ca0d680 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Mon, 9 Oct 2023 06:57:53 +0000 Subject: [PATCH 11/12] support if with inplace --- paddle/fluid/pir/transforms/inplace_pass.cc | 30 ++++++++++++++++----- test/legacy_test/test_cond.py | 1 + 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/pir/transforms/inplace_pass.cc b/paddle/fluid/pir/transforms/inplace_pass.cc index 6010af208fae6..f1566725f9326 100644 --- a/paddle/fluid/pir/transforms/inplace_pass.cc +++ b/paddle/fluid/pir/transforms/inplace_pass.cc @@ -17,6 +17,7 @@ #include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/trait/inplace.h" #include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" @@ -121,11 +122,10 @@ static std::unordered_set GetSkipDeletionValues(pir::Block* block) { // NOTE(zhangbo): For inplace Pass, currently only the kernel_dialect operator // is supported. Therefore, this function only returns the values in the // kernel_dialect operator that can be eager deleted. -static std::unordered_map> -GetEagerDeletionValues(pir::Block* block) { - std::unordered_set skip_dels = GetSkipDeletionValues(block); - - std::unordered_map del_value_2_op; +static void GetEagerDelValueOfOp( + pir::Block* block, + const std::unordered_set& skip_dels, + std::unordered_map* del_value_2_op) { for (auto& op : *block) { std::string upper_op_name = op->name(); if (op->dialect()->name().compare(paddle::dialect::KernelDialect::name()) == @@ -150,16 +150,32 @@ GetEagerDeletionValues(pir::Block* block) { VLOG(8) << " -- is no_need_buffer: " << IsNoNeedBuffer(op, input); continue; } - del_value_2_op[input] = op; + (*del_value_2_op)[input] = op; } for (size_t i = 0; i < op->num_results(); ++i) { pir::Value output = op->result(i); if (output && CanBeDeleted(output)) { - del_value_2_op[output] = op; + (*del_value_2_op)[output] = op; } } + + if (op->isa()) { + auto if_op = op->dyn_cast(); + GetEagerDelValueOfOp(if_op.true_block(), skip_dels, del_value_2_op); + VLOG(8) << "GetEagerDelValueOfOp for IfOp true block"; + GetEagerDelValueOfOp(if_op.false_block(), skip_dels, del_value_2_op); + VLOG(8) << "GetEagerDelValueOfOp for IfOp false block"; + } } +} + +static std::unordered_map> +GetEagerDeletionValues(pir::Block* block) { + std::unordered_set skip_dels = GetSkipDeletionValues(block); + + std::unordered_map del_value_2_op; + GetEagerDelValueOfOp(block, skip_dels, &del_value_2_op); std::unordered_map> eager_dels; diff --git a/test/legacy_test/test_cond.py b/test/legacy_test/test_cond.py index b5ee90871f67f..cec7664ae6cb6 100644 --- a/test/legacy_test/test_cond.py +++ b/test/legacy_test/test_cond.py @@ -265,6 +265,7 @@ def false_func(): np.asarray(ret[1]), np.full((2, 3), True, bool), rtol=1e-05 ) + @test_and_compare_with_new_ir() def test_pass_and_modify_var(self): """ pseudocode: From beaabff52e5b60b1557aa027dfa9bedc4771ab1e Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Tue, 10 Oct 2023 02:52:24 +0000 Subject: [PATCH 12/12] fix bug --- .../instruction/cond_instruction.cc | 19 ++++++++++++------- .../instruction/cond_instruction.h | 2 ++ .../instruction/instruction_util.cc | 7 +++++-- 3 files changed, 19 insertions(+), 9 
deletions(-) diff --git a/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc b/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc index 7a037f0983c64..bcdb7abb7eec1 100644 --- a/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc @@ -64,8 +64,7 @@ void GetInputIds(pir::Operation* op, "input should in name map, [%d] 'th input of [%s] op", i, "if op")); - std::vector inputs_id = GetValueIds(value, value_exec_info); - input_ids->emplace(value, inputs_id); + input_ids->emplace(value, GetValueIds(value, value_exec_info)); } } } @@ -92,9 +91,7 @@ void GetOutsideOpInputs( "input should in name map, [%d] 'th input of [%s] op", i, op->name())); - std::vector inputs_id = GetValueIds(value, value_exec_info); - - input_ids->emplace(value, inputs_id); + input_ids->emplace(value, GetValueIds(value, value_exec_info)); } } } @@ -181,14 +178,22 @@ CondInstruction::CondInstruction(size_t id, "input should in name map, [%d] 'th input of [%s] op", i, "if op")); - std::vector outputs_id = GetValueIds(value, *value_exec_info); - outputs.emplace(value, outputs_id); + outputs.emplace(value, GetValueIds(value, *value_exec_info)); } } SetOutputs(outputs); VLOG(6) << "finish process inputs outputs index"; } +CondInstruction::~CondInstruction() { + if (true_branch_inter_ != nullptr) { + delete true_branch_inter_; + } + if (false_branch_inter_ != nullptr) { + delete false_branch_inter_; + } +} + void CondInstruction::CopyBranchOutput( const std::vector& var_names, const NewIRInterpreter* inter) { for (size_t i = 0; i < var_names.size(); ++i) { diff --git a/paddle/fluid/framework/new_executor/instruction/cond_instruction.h b/paddle/fluid/framework/new_executor/instruction/cond_instruction.h index 974784dbe982a..1cdc4a388126a 100644 --- a/paddle/fluid/framework/new_executor/instruction/cond_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/cond_instruction.h @@ -34,6 +34,8 @@ class CondInstruction : public InstructionBase { ::pir::Operation* op, ValueExecutionInfo* value_exe_info); + ~CondInstruction(); + void Run() override; const std::string& Name() const override { return cond_name_; } diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_util.cc b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc index a9791fefdbc38..6459c094a8e7b 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_util.cc +++ b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc @@ -20,6 +20,8 @@ #include #include "paddle/fluid/framework/new_executor/new_executor_defs.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/event.h" #include "paddle/pir/core/builtin_attribute.h" @@ -148,7 +150,8 @@ OpFuncType AnalyseOpFuncType(pir::Operation* op, const platform::Place& place) { auto& op_attributes = op->attributes(); - if ((op->dialect()->name() == "pd_kernel") && + if ((op->dialect()->name().compare(paddle::dialect::KernelDialect::name()) == + 0) && (op_attributes.count("kernel_key") > 0)) { auto kernel_key = op_attributes.at("kernel_key") .dyn_cast() @@ -179,7 +182,7 @@ OpFuncType AnalyseOpFuncType(pir::Operation* op, const platform::Place& place) { return OpFuncType::kGpuSync; } - if (op_name == "pd_op.shape") { + if (op_name.compare(paddle::dialect::ShapeOp::name()) == 0) { return 
OpFuncType::kGpuSync; } }
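
Taken together, the dispatch that AnalyseOpFuncType converges on across this series can be condensed as below. This is a review sketch under stated assumptions, not the literal function: the coalesce_tensor flag checks are elided, and the helper signatures are taken from the hunks above:

    // Sketch of the post-refactor decision order in AnalyseOpFuncType.
    OpFuncType AnalyseOpFuncTypeSketch(pir::Operation* op,
                                       const platform::Place& place) {
      if (platform::is_cpu_place(place)) {
        return OpFuncType::kCpuSync;  // whole program runs on CPU
      }
      auto& attrs = op->attributes();
      bool is_kernel_dialect =
          op->dialect()->name().compare(
              paddle::dialect::KernelDialect::name()) == 0;
      if (is_kernel_dialect && attrs.count("kernel_key") > 0) {
        auto key = attrs.at("kernel_key")
                       .dyn_cast<paddle::dialect::KernelAttribute>()
                       .data();
        if (phi::TransToPhiPlace(key.backend()).GetType() ==
            phi::AllocationType::CPU) {
          return OpFuncType::kCpuSync;  // CPU kernel inside a GPU program
        }
      }
      if (is_kernel_dialect && attrs.count("op_name") > 0) {
        auto name =
            attrs.at("op_name").dyn_cast<pir::StrAttribute>().AsString();
        // Ops on the device thread that do not launch a CUDA kernel must
        // still synchronize (coalesce_tensor's flag checks elided here).
        if ((platform::is_gpu_place(place) && name == "pd_op.memcpy_d2h") ||
            name.compare(paddle::dialect::ShapeOp::name()) == 0) {
          return OpFuncType::kGpuSync;
        }
      }
      return OpFuncType::kGpuAsync;  // default: asynchronous GPU dispatch
    }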