diff --git a/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt b/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt
index 8d9a93757d3099..93356e6b217a0f 100644
--- a/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt
+++ b/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt
@@ -1,4 +1,5 @@
 cc_library(
   instruction_base
   SRCS instruction_base.cc phi_kernel_instruction.cc
+       legacy_kernel_instruction.cc instruction_util.cc
   DEPS phi framework_proto)
diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_base.cc b/paddle/fluid/framework/new_executor/instruction/instruction_base.cc
index 6c09d7aa2a13fd..56dafd3132c030 100644
--- a/paddle/fluid/framework/new_executor/instruction/instruction_base.cc
+++ b/paddle/fluid/framework/new_executor/instruction/instruction_base.cc
@@ -13,9 +13,15 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/new_executor/instruction/instruction_base.h"
+
+#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h"
 #include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
 #include "paddle/fluid/platform/profiler/event_tracing.h"
 
+#include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h"
+#include "paddle/fluid/platform/collective_helper.h"
+#include "paddle/ir/core/builtin_attribute.h"
+
 namespace paddle {
 namespace framework {
@@ -93,5 +99,59 @@ void InstructionBase::SetOutputs(
   output_index_ = outputs;
 }
 
+void InstructionBase::InitInputsOutputsIds(
+    ::ir::Operation* op,
+    Scope* inner_scope,
+    const std::unordered_map<::ir::Value, std::string>& value_2_var_name,
+    const std::map<std::string, int>& var_name_2_id,
+    const std::unordered_map<const paddle::framework::Variable*, std::string>&
+        variable_2_var_name) {
+  auto op_attributes = op->attributes();
+  auto op_name =
+      op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().AsString();
+  std::unordered_map<ir::Value, std::vector<int>> inputs;
+  for (size_t i = 0; i < op->num_operands(); i++) {
+    ir::Value value = op->operand_source(i);
+    if (value) {
+      PADDLE_ENFORCE_NE(
+          value_2_var_name.find(value),
+          value_2_var_name.end(),
+          phi::errors::PreconditionNotMet(
+              "input should in name map, [%d] 'th input of [%s] op",
+              i,
+              op_name));
+      std::vector<int> inputs_id = GetValueIds(value,
+                                               inner_scope,
+                                               value_2_var_name,
+                                               var_name_2_id,
+                                               variable_2_var_name);
+      inputs.emplace(value, inputs_id);
+    }
+  }
+  SetInputs(inputs);
+  VLOG(8) << "finish process inputs_index";
+  std::unordered_map<ir::Value, std::vector<int>> outputs;
+  for (size_t i = 0; i < op->num_results(); i++) {
+    ir::Value value = op->result(i);
+    if (value && value.type()) {
+      PADDLE_ENFORCE_NE(
+          value_2_var_name.find(value),
+          value_2_var_name.end(),
+          phi::errors::PreconditionNotMet(
+              "input should in name map, [%d] 'th input of [%s] op",
+              i,
+              op_name));
+      std::vector<int> outputs_id = GetValueIds(value,
+                                                inner_scope,
+                                                value_2_var_name,
+                                                var_name_2_id,
+                                                variable_2_var_name);
+      outputs.emplace(value, outputs_id);
+    }
+  }
+  SetOutputs(outputs);
+  VLOG(8) << "finish process outputs_index";
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_base.h b/paddle/fluid/framework/new_executor/instruction/instruction_base.h
index 7452990a1d9076..f078da97107e7e 100644
--- a/paddle/fluid/framework/new_executor/instruction/instruction_base.h
+++ b/paddle/fluid/framework/new_executor/instruction/instruction_base.h
@@ -21,6 +21,7 @@
 
 #include "paddle/fluid/framework/new_executor/new_executor_defs.h"
 #include "paddle/fluid/platform/event.h"
+#include "paddle/ir/core/value.h"
 
 namespace ir {
 class Value;
@@ -137,7 +138,15 @@ class InstructionBase {
 
   virtual const std::string& Name() const = 0;
 
- private:
+  void InitInputsOutputsIds(
+      ::ir::Operation* op,
+      Scope* inner_scope,
+      const std::unordered_map<::ir::Value, std::string>& value_2_var_name,
+      const std::map<std::string, int>& var_name_2_id,
+      const std::unordered_map<const paddle::framework::Variable*, std::string>&
+          variable_2_var_name);
+
+ protected:
   size_t id_;
 
   bool is_artificial_;  // Instruction is artificial means that it is only used
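With InitInputsOutputsIds hoisted into InstructionBase (and its fields made protected), a new instruction type only needs to forward the interpreter's value/variable maps to the base class. A minimal sketch of what a subclass would look like — MyKernelInstruction is purely illustrative and not part of this patch:

#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h"

namespace paddle {
namespace framework {

// Hypothetical subclass (illustration only, not part of this patch).
class MyKernelInstruction : public InstructionBase {
 public:
  MyKernelInstruction(
      size_t id,
      const platform::Place& place,
      ::ir::Operation* op,
      Scope* inner_scope,
      const std::unordered_map<::ir::Value, std::string>& value_2_var_name,
      const std::map<std::string, int>& var_name_2_id,
      const std::unordered_map<const paddle::framework::Variable*, std::string>&
          variable_2_var_name)
      : InstructionBase(id, place) {
    // The shared base-class helper replaces the per-subclass copy that
    // previously lived in PhiKernelInstruction.
    InitInputsOutputsIds(
        op, inner_scope, value_2_var_name, var_name_2_id, variable_2_var_name);
  }

  void Run() override { /* launch the kernel here */ }
  const std::string& Name() const override { return name_; }

 private:
  std::string name_{"my_kernel"};
};

}  // namespace framework
}  // namespace paddle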
diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_util.cc b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc
new file mode 100644
index 00000000000000..d8ddc30633be07
--- /dev/null
+++ b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc
@@ -0,0 +1,175 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <map>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h"
+
+#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/event.h"
+#include "paddle/ir/core/builtin_attribute.h"
+#include "paddle/ir/core/operation.h"
+#include "paddle/ir/core/value.h"
+
+#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
+#include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h"
+#include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h"
+#include "paddle/fluid/platform/collective_helper.h"
+
+namespace paddle {
+namespace framework {
+
+std::vector<int> GetValueIds(
+    ir::Value value,
+    Scope* inner_scope,
+    const std::unordered_map<::ir::Value, std::string>& value_2_var_name,
+    const std::map<std::string, int>& var_name_2_id,
+    const std::unordered_map<const paddle::framework::Variable*, std::string>&
+        variable_2_var_name) {
+  std::vector<int> ids;
+  std::string var_name = value_2_var_name.at(value);
+  ids.push_back(var_name_2_id.at(var_name));
+  // NOTE(zhangbo): Value maybe a VariableRefArray
+  auto var = inner_scope->FindVar(var_name);
+  if (var->IsType<VariableRefArray>()) {
+    auto& var_array = var->Get<VariableRefArray>();
+    for (auto item : var_array) {
+      ids.push_back(var_name_2_id.at(variable_2_var_name.at(item)));
+    }
+  }
+  return ids;
+}
+
+platform::DeviceContext* ParseDeviceContext(
+    ir::Operation* op,
+    platform::DeviceContext* origin_dev_ctx,
+    const platform::Place& place,
+    const std::string& execution_stream,
+    const int stream_priority) {
+  auto op_attributes = op->attributes();
+  auto op_name =
+      op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().AsString();
+  interpreter::ContextManager& ctx_manager =
+      interpreter::ContextManager::Instance();
+
+  platform::DeviceContext* dev_ctx = nullptr;
+
+  // only gpu need update. xpu not need, because xpu memcpy op kernel is
+  // synchronous.
+  if (platform::is_gpu_place(place) || platform::is_custom_place(place)) {
+    VLOG(6) << "Parse DeviceContext for " << op_name
+            << ", execution stream = " << execution_stream;
+    if (execution_stream != kDefaultStream) {
+      dev_ctx = ctx_manager
+                    .Get(std::string(kCustomStream) + "-" + execution_stream,
+                         place,
+                         stream_priority)
+                    .get()
+                    .get();
+      interpreter::SetDeviceCommContext(op, dev_ctx);
+      return dev_ctx;
+    }
+
+    if (op_name == interpreter::kMemcpyD2H) {
+      dev_ctx = ctx_manager.Get(std::string(kD2HStream), place, stream_priority)
+                    .get()
+                    .get();
+      interpreter::SetDeviceCommContext(op, dev_ctx);
+      return dev_ctx;
+    } else if (op_name == interpreter::kMemcpyH2D) {
+      dev_ctx = ctx_manager.Get(std::string(kH2DStream), place, stream_priority)
+                    .get()
+                    .get();
+      interpreter::SetDeviceCommContext(op, dev_ctx);
+      return dev_ctx;
+    }
+
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+    // NOTE(Ruibiao): Here supports multi-stream overlap for c_allreduce_sum
+    // with use_cal_stream==false by returning a device context getting from
+    // the global NCCLCommContext instance. Because when use_calc_stream==false,
+    // in OP kernel, the NCCL communication will be launched to the stream
+    // directly getting from the global NCCLCommContext instance rather than
+    // the DeviceContext passed from executor (see CAllReduceOpCUDAKernel in
+    // c_allreduce_op.h). Now it is just a temporary solution for ONLY
+    // c_allreduce_sum which is used in ResNet50 distributed training.
+    if (op_name == "c_allreduce_sum" && op_attributes.at("use_calc_stream")
+                                                .dyn_cast<::ir::BoolAttribute>()
+                                                .data() == false) {
+      int ring_id =
+          op_attributes.at("ring_id").dyn_cast<::ir::Int32Attribute>().data();
+      return platform::NCCLCommContext::Instance()
+          .Get(ring_id, place)
+          ->dev_context();
+    }
+#endif
+  }
+
+  if (origin_dev_ctx != nullptr) {
+    interpreter::SetDeviceCommContext(op, origin_dev_ctx);
+  }
+  return origin_dev_ctx;
+}
+
+OpFuncType AnalyseOpFuncType(::ir::Operation* op,
+                             const platform::Place& place) {
+  if (platform::is_cpu_place(place)) {
+    return OpFuncType::kCpuSync;
+  }
+
+  auto kernel_key = op->attributes()
+                        .at("kernel_key")
+                        .dyn_cast<paddle::dialect::KernelAttribute>()
+                        .data();
+  if (phi::TransToPhiPlace(kernel_key.backend()).GetType() ==
+      phi::AllocationType::CPU) {
+    return OpFuncType::kCpuSync;
+  }
+
+  PADDLE_ENFORCE_EQ(interpreter::IsSupportedHeterPlace(place),
+                    true,
+                    phi::errors::Fatal("Unsupported current place %s", place));
+
+  // Some GPU OPs do not launch CUDA Kernel, but spend a lot of time on CPU
+  // computing. They execute serially in device thread and block CUDA kernel
+  // launching in other GPU OPs. To improve performance, set them as kGpuSync
+  // and so that they would be dispatched to host thread.
+  auto op_attributes = op->attributes();
+  auto op_name =
+      op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().AsString();
+  if (op_name == kCoalesceTensor &&
+      (!platform::is_xpu_place(place) ||
+       op->attribute<::ir::BoolAttribute>("persist_output").data() == false) &&
+      op->attribute<::ir::BoolAttribute>("set_constant").data() == false &&
+      op->attribute<::ir::BoolAttribute>("copy_data").data() == false) {
+    return OpFuncType::kGpuSync;
+  }
+
+  // for memcpy explicitly called by user
+  if (platform::is_gpu_place(place) && op_name == interpreter::kMemcpyD2H) {
+    return OpFuncType::kGpuSync;
+  }
+
+  if (op_name == "shape") {
+    return OpFuncType::kGpuSync;
+  }
+  return OpFuncType::kGpuAsync;
+}
+
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_util.h b/paddle/fluid/framework/new_executor/instruction/instruction_util.h
new file mode 100644
index 00000000000000..a41ce07957e4ae
--- /dev/null
+++ b/paddle/fluid/framework/new_executor/instruction/instruction_util.h
@@ -0,0 +1,49 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <map>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/event.h"
+#include "paddle/ir/core/builtin_attribute.h"
+#include "paddle/ir/core/operation.h"
+#include "paddle/ir/core/value.h"
+
+namespace paddle {
+namespace framework {
+
+std::vector<int> GetValueIds(
+    ir::Value value,
+    Scope* inner_scope,
+    const std::unordered_map<::ir::Value, std::string>& value_2_var_name,
+    const std::map<std::string, int>& var_name_2_id,
+    const std::unordered_map<const paddle::framework::Variable*, std::string>&
+        variable_2_var_name);
+
+platform::DeviceContext* ParseDeviceContext(
+    ir::Operation* op,
+    platform::DeviceContext* origin_dev_ctx,
+    const platform::Place& place,
+    const std::string& execution_stream,
+    const int stream_priority);
+
+OpFuncType AnalyseOpFuncType(::ir::Operation* op, const platform::Place& place);
+
+}  // namespace framework
+}  // namespace paddle
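Taken together, these helpers split an instruction's construction into two decisions: how the op is scheduled (AnalyseOpFuncType) and where it runs (ParseDeviceContext). A rough sketch of a call site, assuming the default stream and priority — WireUp itself is hypothetical, and kDefaultStream is assumed to come from the stream analyzer header:

#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h"
#include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h"

namespace paddle {
namespace framework {

// Hypothetical call site (sketch): how an instruction constructor combines
// the helpers declared in instruction_util.h.
void WireUp(::ir::Operation* op,
            const platform::Place& place,
            platform::DeviceContext* origin_dev_ctx) {
  // 1. Classify the op: CPU-sync, GPU-sync (host-blocking), or GPU-async.
  OpFuncType func_type = AnalyseOpFuncType(op, place);
  // 2. Pick a device context, honoring custom execution streams and the
  //    dedicated D2H/H2D memcpy streams.
  platform::DeviceContext* dev_ctx =
      ParseDeviceContext(op,
                         origin_dev_ctx,
                         place,
                         /*execution_stream=*/kDefaultStream,
                         /*stream_priority=*/0);
  VLOG(6) << "func_type=" << static_cast<int>(func_type)
          << ", dev_ctx=" << dev_ctx;
}

}  // namespace framework
}  // namespace paddle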
diff --git a/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc
new file mode 100644
index 00000000000000..eadf0c1f806cf4
--- /dev/null
+++ b/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc
@@ -0,0 +1,182 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.h"
+
+#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h"
+#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
+#include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/ir/dialect/pd_dialect.h"
+#include "paddle/fluid/ir/interface/infermeta.h"
+#include "paddle/fluid/ir/interface/op_yaml_info.h"
+#include "paddle/fluid/ir/interface/op_yaml_info_parser.h"
+#include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h"
+
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/core/meta_tensor.h"
+#include "paddle/phi/core/type_defs.h"
+
+namespace paddle {
+namespace framework {
+
+LegacyKernelInstruction::LegacyKernelInstruction(
+    size_t id,
+    const platform::Place& place,
+    ir::Operation* op,
+    Scope* scope,
+    Scope* local_scope,
+    const std::unordered_map<::ir::Value, std::string>& value_2_var_name,
+    const std::map<std::string, int>& var_name_2_id,
+    const std::unordered_map<const paddle::framework::Variable*, std::string>&
+        variable_2_var_name)
+    : InstructionBase(id, place) {
+  auto op_attributes = op->attributes();
+  auto op_name =
+      op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().AsString();
+  ir::OpInfo op_info = ir::IrContext::Instance()->GetRegisteredOpInfo(op_name);
+
+  legacy_op_name_ = op_name;
+  VLOG(6) << "construct legacy kernel instruction for: " << legacy_op_name_;
+
+  // Todo: support paddle::dialect::DistAttribute
+  // if (op_attributes.count("dist_attr") != 0) {
+  //   if (op_attributes.count("execution_stream") != 0) {
+  //     SetExecutionStream(op_attributes.at("execution_stream")
+  //                            .dyn_cast<::ir::StrAttribute>()
+  //                            .data());
+  //   }
+  //   if (op_attributes.count("stream_priority") != 0) {
+  //     SetStreamPriority(op_attributes.at("stream_priority")
+  //                           .dyn_cast<::ir::Int32Attribute>()
+  //                           .data());
+  //   }
+  //   if (op_attributes.count("scheduling_priority") != 0) {
+  //     SetSchedulingPriority(op_attributes.at("scheduling_priority")
+  //                               .dyn_cast<::ir::Int64Attribute>()
+  //                               .data());
+  //   }
+  // } else {
+  //   if (interpreter::IsCommunicationOp(op)) {
+  //     // NOTE(Ruibiao): Dispatching computation before communication
+  //     // improves multi-stream overlap when the time cost of communication
+  //     // less than that of the calculation (e.g., ResNet50_bs128_pure_fp16
+  //     // N4C32 training).
+  //     op_func_node.scheduling_priority_ = 1;
+  //   }
+  // }
+  VLOG(6) << "finish process dist attributes";
+
+  SetKernelType(AnalyseOpFuncType(op, place));
+  VLOG(6) << "finish process analyse kernel type";
+
+  infer_meta_interface_ =
+      op_info.GetInterfaceImpl<paddle::dialect::InferMetaInterface>();
+  VLOG(6) << "finish process infer_meta_interface_";
+
+  auto yaml_interface =
+      op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>();
+  PADDLE_ENFORCE_NOT_NULL(
+      yaml_interface,
+      phi::errors::PreconditionNotMet(
+          "can not find OpYamlInfoInterface from [%s]", legacy_op_name_));
+  paddle::dialect::OpYamlInfoParser yaml_info_parser(
+      yaml_interface->get_op_info_());
+  VLOG(6) << "finish process yaml_info_parser";
+
+  ::ir::BuildPhiContext<
+      phi::InferMetaContext,
+      phi::MetaTensor,
+      phi::MetaTensor,
+      paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
+      paddle::small_vector<phi::MetaTensor, phi::kOutputSmallVectorSize>,
+      false>(op,
+             value_2_var_name,
+             scope,
+             local_scope,
+             yaml_info_parser,
+             &infer_meta_context_);
+  VLOG(6) << "finish process infer meta context";
+
+  auto kernel_name =
+      op_attributes.at("kernel_name").dyn_cast<ir::StrAttribute>().AsString();
+  auto kernel_key = op_attributes.at("kernel_key")
+                        .dyn_cast<paddle::dialect::KernelAttribute>()
+                        .data();
+  auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
+      kernel_name, kernel_key);
+  phi_kernel_ = new phi::Kernel(kernel_result.kernel);
+  PADDLE_ENFORCE_EQ(
+      phi_kernel_->IsValid(), true, "not found kernel for [%s]", kernel_name);
+  VLOG(6) << "finish process select kernel";
+
+  operator_base_ =
+      ir::BuildOperatorBase(op, value_2_var_name, yaml_info_parser);
+  paddle::framework::VariableValueMap in_map;
+  paddle::framework::VariableValueMap out_map;
+  auto dev_ctx = phi::DeviceContextPool::Instance().Get(
+      phi::TransToPhiPlace(kernel_key.backend()));
+
+  runtime_context_ = std::make_shared<paddle::framework::RuntimeContext>(
+      paddle::framework::RuntimeContext(in_map, out_map));
+  ir::BuildRuntimeContext(op,
+                          value_2_var_name,
+                          scope,
+                          local_scope,
+                          yaml_info_parser,
+                          runtime_context_.get());
+  kernel_context_ = new paddle::framework::ExecutionContext(
+      *operator_base_, *local_scope, *dev_ctx, *(runtime_context_.get()));
+
+  VLOG(6) << "finish process kernel context";
+  SetDeviceContext(
+      ParseDeviceContext(op,
+                         phi::DeviceContextPool::Instance().Get(
+                             phi::TransToPhiPlace(kernel_key.backend())),
+                         place,
+                         GetExecutionStream(),
+                         GetStreamPriority()));
+  VLOG(6) << "finish process device context";
+
+  Scope* inner_scope = local_scope == nullptr ? scope : local_scope;
+  InitInputsOutputsIds(
+      op, inner_scope, value_2_var_name, var_name_2_id, variable_2_var_name);
+  VLOG(6) << "finish process inputs outputs index";
+
+  auto& no_need_buffer_ids = yaml_info_parser.NoNeedBufferIds();
+  std::unordered_set<::ir::Value> no_need_buffer_values;
+  for (size_t id = 0; id < no_need_buffer_ids.size(); id++) {
+    no_need_buffer_values.insert(op->operand_source(no_need_buffer_ids[id]));
+  }
+  SetNoNeedBuffer(no_need_buffer_values);
+  VLOG(6) << "finish process no need buffer";
+}
+
+LegacyKernelInstruction::~LegacyKernelInstruction() {
+  if (kernel_context_ != nullptr) {
+    delete kernel_context_;
+  }
+}
+
+void LegacyKernelInstruction::Run() {
+  infer_meta_interface_->infer_meta_(&(infer_meta_context_));
+  VLOG(6) << "Run op " << legacy_op_name_ << " infer meta.";
+  (*(phi_kernel_))((kernel_context_));
+  VLOG(6) << "Run op " << legacy_op_name_ << " kernel.";
+}
+
+}  // namespace framework
+}  // namespace paddle
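For orientation, this is roughly how the interpreter drives the new instruction; the sketch mirrors the NewIRInterpreter::BuildInstruction change further down in this diff, and the bare variables stand in for interpreter members:

#include "paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.h"

namespace paddle {
namespace framework {

// Sketch only (not part of this patch): build and run one legacy instruction.
void BuildAndRunOnce(
    const platform::Place& place,
    ::ir::Operation* op,
    Scope* scope,
    Scope* local_scope,
    const std::unordered_map<::ir::Value, std::string>& value_2_var_name,
    const std::map<std::string, int>& var_name_2_id,
    const std::unordered_map<const paddle::framework::Variable*, std::string>&
        variable_2_var_name) {
  auto instr = std::make_unique<LegacyKernelInstruction>(/*id=*/0,
                                                         place,
                                                         op,
                                                         scope,
                                                         local_scope,
                                                         value_2_var_name,
                                                         var_name_2_id,
                                                         variable_2_var_name);
  // InferMeta first, then the selected phi::Kernel through the fluid
  // ExecutionContext -- this is what makes "legacy" (fluid-style) kernels
  // callable from the new IR executor.
  instr->Run();
}

}  // namespace framework
}  // namespace paddle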
diff --git a/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.h b/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.h
new file mode 100644
index 00000000000000..27c1cb133bec01
--- /dev/null
+++ b/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.h
@@ -0,0 +1,72 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h"
+
+namespace ir {
+class Operation;
+class Value;
+}  // namespace ir
+
+namespace paddle {
+namespace framework {
+class Scope;
+
+class LegacyKernelInstruction : public InstructionBase {
+ public:
+  LegacyKernelInstruction(
+      size_t id,
+      const platform::Place& place,
+      ::ir::Operation* op,
+      Scope* scope,
+      Scope* local_scope,
+      const std::unordered_map<::ir::Value, std::string>& value_2_var_name,
+      const std::map<std::string, int>& var_name_2_id,
+      const std::unordered_map<const paddle::framework::Variable*, std::string>&
+          variable_2_var_name);
+
+  ~LegacyKernelInstruction();
+
+  phi::Kernel* PhiKernel() const { return phi_kernel_; }
+
+  const phi::InferMetaContext& InferMetaContext() const {
+    return infer_meta_context_;
+  }
+
+  paddle::dialect::InferMetaInterface::Concept* InferMetaInterface() const {
+    return infer_meta_interface_;
+  }
+
+  void Run() override;
+
+  const std::string& Name() const override { return legacy_op_name_; }
+
+ private:
+  std::string legacy_op_name_;
+
+  paddle::dialect::InferMetaInterface::Concept* infer_meta_interface_{
+      nullptr};  // not owned
+
+  phi::InferMetaContext infer_meta_context_;
+
+  paddle::framework::ExecutionContext* kernel_context_{nullptr};
+  std::shared_ptr<paddle::framework::RuntimeContext> runtime_context_;
+  std::shared_ptr<paddle::framework::OperatorBase> operator_base_;
+
+  phi::Kernel* phi_kernel_{nullptr};  // not owned
+};
+
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc
index 11c2a3814a0133..d5b7b5affc5d4b 100644
--- a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc
+++ b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc
@@ -32,124 +32,10 @@
 #include "paddle/ir/core/operation.h"
 #include "paddle/ir/core/value.h"
 
+#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h"
 namespace paddle {
 namespace framework {
 
-platform::DeviceContext* ParseDeviceContext(
-    ir::Operation* op,
-    platform::DeviceContext* origin_dev_ctx,
-    const platform::Place& place,
-    const std::string& execution_stream,
-    const int stream_priority) {
-  auto op_attributes = op->attributes();
-  auto op_name =
-      op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().AsString();
-  interpreter::ContextManager& ctx_manager =
-      interpreter::ContextManager::Instance();
-
-  platform::DeviceContext* dev_ctx = nullptr;
-
-  // only gpu need update. xpu not need, because xpu memcpy op kernel is
-  // synchronous.
-  if (platform::is_gpu_place(place) || platform::is_custom_place(place)) {
-    VLOG(6) << "Parse DeviceContext for " << op_name
-            << ", execution stream = " << execution_stream;
-    if (execution_stream != kDefaultStream) {
-      dev_ctx = ctx_manager
-                    .Get(std::string(kCustomStream) + "-" + execution_stream,
-                         place,
-                         stream_priority)
-                    .get()
-                    .get();
-      interpreter::SetDeviceCommContext(op, dev_ctx);
-      return dev_ctx;
-    }
-
-    if (op_name == interpreter::kMemcpyD2H) {
-      dev_ctx = ctx_manager.Get(std::string(kD2HStream), place, stream_priority)
-                    .get()
-                    .get();
-      interpreter::SetDeviceCommContext(op, dev_ctx);
-      return dev_ctx;
-    } else if (op_name == interpreter::kMemcpyH2D) {
-      dev_ctx = ctx_manager.Get(std::string(kH2DStream), place, stream_priority)
-                    .get()
-                    .get();
-      interpreter::SetDeviceCommContext(op, dev_ctx);
-      return dev_ctx;
-    }
-
-#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-    // NOTE(Ruibiao): Here supports multi-stream overlap for c_allreduce_sum
-    // with use_cal_stream==false by returning a device context getting from
-    // the global NCCLCommContext instance. Because when use_calc_stream==false,
-    // in OP kernel, the NCCL communication will be launched to the stream
-    // directly getting from the global NCCLCommContext instance rather than
-    // the DeviceContext passed from executor (see CAllReduceOpCUDAKernel in
-    // c_allreduce_op.h). Now it is just a temporary solution for ONLY
-    // c_allreduce_sum which is used in ResNet50 distributed training.
-    if (op_name == "c_allreduce_sum" && op_attributes.at("use_calc_stream")
-                                                .dyn_cast<::ir::BoolAttribute>()
-                                                .data() == false) {
-      int ring_id =
-          op_attributes.at("ring_id").dyn_cast<::ir::Int32Attribute>().data();
-      return platform::NCCLCommContext::Instance()
-          .Get(ring_id, place)
-          ->dev_context();
-    }
-#endif
-  }
-
-  if (origin_dev_ctx != nullptr) {
-    interpreter::SetDeviceCommContext(op, origin_dev_ctx);
-  }
-  return origin_dev_ctx;
-}
-
-OpFuncType AnalyseOpFuncType(ir::Operation* op, const platform::Place& place) {
-  if (platform::is_cpu_place(place)) {
-    return OpFuncType::kCpuSync;
-  }
-
-  auto kernel_key = op->attributes()
-                        .at("kernel_key")
-                        .dyn_cast<paddle::dialect::KernelAttribute>()
-                        .data();
-  if (phi::TransToPhiPlace(kernel_key.backend()).GetType() ==
-      phi::AllocationType::CPU) {
-    return OpFuncType::kCpuSync;
-  }
-
-  PADDLE_ENFORCE_EQ(interpreter::IsSupportedHeterPlace(place),
-                    true,
-                    phi::errors::Fatal("Unsupported current place %s", place));
-
-  // Some GPU OPs do not launch CUDA Kernel, but spend a lot of time on CPU
-  // computing. They execute serially in device thread and block CUDA kernel
-  // launching in other GPU OPs. To improve performance, set them as kGpuSync
-  // and so that they would be dispatched to host thread.
-  auto op_attributes = op->attributes();
-  auto op_name =
-      op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().AsString();
-  if (op_name == kCoalesceTensor &&
-      (!platform::is_xpu_place(place) ||
-       op->attribute<::ir::BoolAttribute>("persist_output").data() == false) &&
-      op->attribute<::ir::BoolAttribute>("set_constant").data() == false &&
-      op->attribute<::ir::BoolAttribute>("copy_data").data() == false) {
-    return OpFuncType::kGpuSync;
-  }
-
-  // for memcpy explicitly called by user
-  if (platform::is_gpu_place(place) && op_name == interpreter::kMemcpyD2H) {
-    return OpFuncType::kGpuSync;
-  }
-
-  if (op_name == "shape") {
-    return OpFuncType::kGpuSync;
-  }
-  return OpFuncType::kGpuAsync;
-}
-
 PhiKernelInstruction::PhiKernelInstruction(
     size_t id,
     const platform::Place& place,
@@ -281,78 +167,6 @@ PhiKernelInstruction::PhiKernelInstruction(
   VLOG(6) << "finish process no need buffer";
 }
 
-std::vector<int> GetValueIds(
-    ir::Value value,
-    Scope* inner_scope,
-    const std::unordered_map<::ir::Value, std::string>& value_2_var_name,
-    const std::map<std::string, int>& var_name_2_id,
-    const std::unordered_map<const paddle::framework::Variable*, std::string>&
-        variable_2_var_name) {
-  std::vector<int> ids;
-  std::string var_name = value_2_var_name.at(value);
-  ids.push_back(var_name_2_id.at(var_name));
-  // NOTE(zhangbo): Value maybe a VariableRefArray
-  auto var = inner_scope->FindVar(var_name);
-  if (var->IsType<VariableRefArray>()) {
-    auto& var_array = var->Get<VariableRefArray>();
-    for (auto item : var_array) {
-      ids.push_back(var_name_2_id.at(variable_2_var_name.at(item)));
-    }
-  }
-  return ids;
-}
-
-void PhiKernelInstruction::InitInputsOutputsIds(
-    ::ir::Operation* op,
-    Scope* inner_scope,
-    const std::unordered_map<::ir::Value, std::string>& value_2_var_name,
-    const std::map<std::string, int>& var_name_2_id,
-    const std::unordered_map<const paddle::framework::Variable*, std::string>&
-        variable_2_var_name) {
-  std::unordered_map<ir::Value, std::vector<int>> inputs;
-  for (size_t i = 0; i < op->num_operands(); i++) {
-    ir::Value value = op->operand_source(i);
-    if (value) {
-      PADDLE_ENFORCE_NE(
-          value_2_var_name.find(value),
-          value_2_var_name.end(),
-          phi::errors::PreconditionNotMet(
-              "input should in name map, [%d] 'th input of [%s] op",
-              i,
-              phi_op_name_));
-      std::vector<int> inputs_id = GetValueIds(value,
-                                               inner_scope,
-                                               value_2_var_name,
-                                               var_name_2_id,
-                                               variable_2_var_name);
-      inputs.emplace(value, inputs_id);
-    }
-  }
-  SetInputs(inputs);
-  VLOG(8) << "finish process inputs_index";
-  std::unordered_map<ir::Value, std::vector<int>> outputs;
-  for (size_t i = 0; i < op->num_results(); i++) {
-    ir::Value value = op->result(i);
-    if (value && value.type()) {
-      PADDLE_ENFORCE_NE(
-          value_2_var_name.find(value),
-          value_2_var_name.end(),
-          phi::errors::PreconditionNotMet(
-              "input should in name map, [%d] 'th input of [%s] op",
-              i,
-              phi_op_name_));
-      std::vector<int> outputs_id = GetValueIds(value,
-                                                inner_scope,
-                                                value_2_var_name,
-                                                var_name_2_id,
-                                                variable_2_var_name);
-      outputs.emplace(value, outputs_id);
-    }
-  }
-  SetOutputs(outputs);
-  VLOG(8) << "finish process outputs_index";
-}
-
 void PhiKernelInstruction::Run() {
   if (infer_meta_interface_) {
     infer_meta_interface_->infer_meta_(&(infer_meta_context_));
diff --git a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h
index b30fa8bff751b5..c637cce8651fbf 100644
--- a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h
+++ b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h
@@ -55,16 +55,6 @@ class PhiKernelInstruction : public InstructionBase {
   const std::string& Name() const override { return phi_op_name_; }
 
  private:
-  void InitInputsOutputsIds(
-      ::ir::Operation* op,
-      Scope* inner_scope,
-      const std::unordered_map<::ir::Value, std::string>& value_2_var_name,
-      const std::map<std::string, int>& var_name_2_id,
-      const std::unordered_map<const paddle::framework::Variable*, std::string>&
-          variable_2_var_name);
-
-  std::string phi_op_name_;
-
   paddle::dialect::InferMetaInterface::Concept* infer_meta_interface_{
       nullptr};  // not owned
 
@@ -73,6 +63,8 @@ class PhiKernelInstruction : public InstructionBase {
   phi::KernelContext kernel_context_;
 
   phi::Kernel* phi_kernel_{nullptr};  // not owned
+
+  std::string phi_op_name_;
 };
 
 }  // namespace framework
diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc
index 04e1457f33dcbf..9ee34fcc39c115 100644
--- a/paddle/fluid/framework/new_executor/interpretercore.cc
+++ b/paddle/fluid/framework/new_executor/interpretercore.cc
@@ -125,5 +125,6 @@ const platform::Place& InterpreterCore::GetPlace() const {
 void InterpreterCore::SetOutputHooks(const std::vector<HookFunc>& hookfuncs) {
   impl_->SetOutputHooks(hookfuncs);
 }
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc
index cf3d8c95092386..c5d31728e46b2c 100644
--- a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc
+++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc
@@ -36,6 +36,7 @@
 #include "paddle/fluid/platform/flags.h"
 #include "paddle/phi/backends/device_manager.h"
 
+#include "paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.h"
 #include "paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h"
 #include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h"
 #include "paddle/ir/core/builtin_attribute.h"
@@ -251,7 +252,6 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
   // return Fetch Tensors
   Scope* inner_scope = InnerScope();
-
   if (FLAGS_enable_new_ir_in_executor) {
     framework::FetchList fetch_res;
@@ -1545,15 +1545,29 @@ void NewIRInterpreter::BuildInstruction() {
       VLOG(6) << "skip process " << op_name;
       continue;
     }
-    vec_instruction_base_.emplace_back(
-        std::make_unique<PhiKernelInstruction>(op_idx++,
-                                               place_,
-                                               op,
-                                               scope_,
-                                               local_scope_,
-                                               value_2_var_name_,
-                                               var_name_2_id_,
-                                               variable_2_var_name_));
+
+    if (op_name == "pd.fused_softmax_mask_upper_triangle" ||
+        op_name == "pd.fused_softmax_mask_upper_triangle_grad") {
+      vec_instruction_base_.emplace_back(
+          std::make_unique<LegacyKernelInstruction>(op_idx++,
+                                                    place_,
+                                                    op,
+                                                    scope_,
+                                                    local_scope_,
+                                                    value_2_var_name_,
+                                                    var_name_2_id_,
+                                                    variable_2_var_name_));
+    } else {
+      vec_instruction_base_.emplace_back(
+          std::make_unique<PhiKernelInstruction>(op_idx++,
+                                                 place_,
+                                                 op,
+                                                 scope_,
+                                                 local_scope_,
+                                                 value_2_var_name_,
+                                                 var_name_2_id_,
+                                                 variable_2_var_name_));
+    }
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
         "Now only support pd_kernel dialect."));
diff --git a/paddle/fluid/framework/new_executor/standalone_executor.cc b/paddle/fluid/framework/new_executor/standalone_executor.cc
index 6573848f433ef8..0e1d2de6bed29c 100644
--- a/paddle/fluid/framework/new_executor/standalone_executor.cc
+++ b/paddle/fluid/framework/new_executor/standalone_executor.cc
@@ -159,7 +159,6 @@ paddle::framework::FetchList StandaloneExecutor::Run(
   // return Fetch Tensors
   if (FLAGS_enable_new_ir_in_executor) {
     framework::FetchList fetch_res;
-
     for (auto& var_name : fetch_var_names_) {
       auto* var = scope_->FindVar(var_name);
       fetch_res.push_back(var->Get<phi::DenseTensor>());
diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
index e41004244c0de2..42b547b08d3b38 100644
--- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
+++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
@@ -234,6 +234,7 @@ void HandleForSpecialOp(
               ->Var(fetch_var_name);
   var->GetMutable<phi::DenseTensor>();
   auto value = op->result(0);
+
   AddNewData(value,
              fetch_var_name,
              var,
diff --git a/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc
index a6c2b6a4cdc33d..c211812f569bd9 100644
--- a/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc
+++ b/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc
@@ -59,6 +59,10 @@ const std::unordered_set<std::string> UnchangeOutputOps = {
     "builtin.get_parameter",
     "pd.shadow_output"};
 
+const std::unordered_set<std::string> LegacyOpList = {
+    "pd.fused_softmax_mask_upper_triangle",
+    "pd.fused_softmax_mask_upper_triangle_grad"};
+
 bool NeedFallBackCpu(const ir::Operation* op,
                      const std::string& kernel_fn_name,
                      const phi::KernelKey& kernel_key) {
@@ -401,7 +405,8 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
         kernel_fn_str, kernel_key);
     auto args_def = phi_kernel.args_def();
     auto output_defs = args_def.output_defs();
-    if (!UnchangeOutputOps.count(op_item->name())) {
+    if (!UnchangeOutputOps.count(op_item->name()) &&
+        !LegacyOpList.count(op_item->name())) {
       PADDLE_ENFORCE_EQ(
           op_item->num_results(),
           output_defs.size(),
@@ -413,7 +418,7 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
     for (size_t i = 0; i < op_item->num_results(); ++i) {
       phi::Place out_place;
       if ((!UnchangeOutputOps.count(op_item->name())) &&
-          phi_kernel.IsValid()) {
+          (!LegacyOpList.count(op_item->name())) && phi_kernel.IsValid()) {
         out_place = phi::TransToPhiPlace(output_defs[i].backend);
       } else {
         out_place = phi::TransToPhiPlace(kernel_key.backend());
diff --git a/test/cpp/ir/pattern_rewrite/CMakeLists.txt b/test/cpp/ir/pattern_rewrite/CMakeLists.txt
index fd527db555003e..2023cc0cf413f3 100644
--- a/test/cpp/ir/pattern_rewrite/CMakeLists.txt
+++ b/test/cpp/ir/pattern_rewrite/CMakeLists.txt
@@ -7,3 +7,7 @@ endif()
 
 cc_test_old(pattern_rewrite_test SRCS pattern_rewrite_test.cc DEPS
             ${PATTERN_REWRITE_TEST_DEPS})
+
+set_tests_properties(
+  pattern_rewrite_test PROPERTIES ENVIRONMENT
+                                  "FLAGS_enable_new_ir_in_executor=true")
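To route another fluid-style op through the legacy path, its name would have to be added both to LegacyOpList in pd_op_to_kernel_pass.cc and to the matching dispatch in NewIRInterpreter::BuildInstruction. A hypothetical extension — the third entry is an illustrative placeholder, not part of this patch:

// Hypothetical extension (sketch): ops listed here skip the output-count
// check against the phi kernel's args_def during lowering; the interpreter
// must then build them as LegacyKernelInstruction via a matching branch in
// NewIRInterpreter::BuildInstruction.
const std::unordered_set<std::string> LegacyOpList = {
    "pd.fused_softmax_mask_upper_triangle",
    "pd.fused_softmax_mask_upper_triangle_grad",
    "pd.some_fluid_only_op",  // <- illustrative placeholder, not in this PR
};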