diff --git a/.gitignore b/.gitignore index 232d8fa08b4bd6..c4046a8d6b6e38 100644 --- a/.gitignore +++ b/.gitignore @@ -108,6 +108,8 @@ paddle/fluid/pir/dialect/operator/ir/pd_api.* paddle/fluid/pir/dialect/operator/ir/op_decomp.cc paddle/fluid/pir/dialect/operator/ir/pd_op_vjp.cc paddle/fluid/pir/dialect/operator/ir/pd_op.* +paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.* +paddle/fluid/pir/dialect/operator/ir/pd_onednn_op_info.* paddle/fluid/pir/dialect/operator/ir/pd_op_bwd.* paddle/fluid/pir/dialect/operator/ir/pd_op_fused.* paddle/fluid/pir/dialect/operator/ir/pd_op_fused_bwd.* diff --git a/paddle/fluid/framework/new_executor/CMakeLists.txt b/paddle/fluid/framework/new_executor/CMakeLists.txt index df01de6d424919..990f82efa8edeb 100644 --- a/paddle/fluid/framework/new_executor/CMakeLists.txt +++ b/paddle/fluid/framework/new_executor/CMakeLists.txt @@ -5,6 +5,16 @@ if(NOT (WITH_CINN AND NOT CINN_ONLY)) ${CMAKE_CURRENT_SOURCE_DIR}/instruction/cinn_jit_instruction.cc) endif() +if(NOT WITH_MKLDNN) + list( + REMOVE_ITEM + standalone_executor_srcs + ${CMAKE_CURRENT_SOURCE_DIR}/instruction/onednn/onednn_legacy_kernel_instruction.cc + ${CMAKE_CURRENT_SOURCE_DIR}/instruction/onednn/onednn_phi_kernel_instruction.cc + ${CMAKE_CURRENT_SOURCE_DIR}/instruction/onednn/onednn_mixed_phi_kernel_instruction.cc + ) +endif() + set(standalone_executor_deps pir program_translator diff --git a/paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_kernel_instruction.cc new file mode 100644 index 00000000000000..6d1944219a2dc9 --- /dev/null +++ b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_kernel_instruction.cc @@ -0,0 +1,52 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_kernel_instruction.h" + +#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h" +#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" +#include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h" +#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" +#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" + +#include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/core/meta_tensor.h" +#include "paddle/phi/core/type_defs.h" + +namespace paddle { +namespace framework { + +OneDNNLegacyKernelInstruction::OneDNNLegacyKernelInstruction( + size_t id, + const platform::Place& place, + pir::Operation* op, + const ValueExecutionInfo* value_exec_info) + : InstructionBase(id, place), value_exec_info_(value_exec_info) { + PADDLE_THROW(platform::errors::Unimplemented( + "OneDNNLegacyKernelInstruction not defined now.")); +} + +OneDNNLegacyKernelInstruction::~OneDNNLegacyKernelInstruction() {} + +void OneDNNLegacyKernelInstruction::Run() { + PADDLE_THROW(platform::errors::Unimplemented( + "OneDNNLegacyKernelInstruction not defined now.")); +} +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_kernel_instruction.h b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_kernel_instruction.h new file mode 100644 index 00000000000000..e5c7b0cd151765 --- /dev/null +++ b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_kernel_instruction.h @@ -0,0 +1,72 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h" + +namespace pir { +class Operation; +} // namespace pir + +namespace paddle { +namespace framework { +class Scope; +class ValueExecutionInfo; + +class OneDNNLegacyKernelInstruction : public InstructionBase { + public: + OneDNNLegacyKernelInstruction(size_t id, + const platform::Place& place, + ::pir::Operation* op, + const ValueExecutionInfo* value_exec_info); + + ~OneDNNLegacyKernelInstruction(); + phi::Kernel* PhiKernel() const { return phi_kernel_; } + + const phi::InferMetaContext& InferMetaContext() const { + return infer_meta_context_; + } + + paddle::dialect::InferMetaInterface::Concept* InferMetaInterface() const { + return infer_meta_interface_; + } + + void Run() override; + + const std::string& Name() const override { return legacy_op_name_; } + + ::pir::Operation* Operation() const override { return op_; } + + private: + std::string legacy_op_name_; + + paddle::dialect::InferMetaInterface::Concept* infer_meta_interface_{ + nullptr}; // not owned + + phi::InferMetaContext infer_meta_context_; + + paddle::framework::ExecutionContext* kernel_context_{nullptr}; + std::shared_ptr runtime_context_; + std::shared_ptr operator_base_; + + phi::Kernel* phi_kernel_{nullptr}; // not owned + + ::pir::Operation* op_{nullptr}; // not owned + + const ValueExecutionInfo* value_exec_info_; // not owned +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/instruction/onednn/onednn_mixed_phi_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_mixed_phi_kernel_instruction.cc new file mode 100644 index 00000000000000..572c26eb420789 --- /dev/null +++ b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_mixed_phi_kernel_instruction.cc @@ -0,0 +1,61 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/new_executor/instruction/onednn/onednn_mixed_phi_kernel_instruction.h" + +#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" +#include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h" +#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" +#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/core/meta_tensor.h" +#include "paddle/phi/core/type_defs.h" + +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/value.h" + +#include "dnnl.hpp" // NOLINT +#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h" +#include "paddle/fluid/framework/type_defs.h" +#include "paddle/fluid/ir_adaptor/translator/op_compat_info.h" +#include "paddle/phi/backends/onednn/onednn_context.h" +#include "paddle/phi/backends/onednn/onednn_helper.h" +#include "paddle/phi/kernels/funcs/data_layout_transform.h" + +namespace paddle { +namespace framework { + +OneDNNMixedPhiKernelInstruction::OneDNNMixedPhiKernelInstruction( + size_t id, + const platform::Place& place, + pir::Operation* op, + const ValueExecutionInfo* value_exec_info) + : OneDNNPhiKernelInstruction(id, place, op, value_exec_info) {} + +void OneDNNMixedPhiKernelInstruction::Run() { + // Step1. Mixed Dynamic Choose Kernel + // todo if (input_tensor.layout() != phi::DataLayout::ONEDNN) + + OneDNNPhiKernelInstruction::Run(); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/instruction/onednn/onednn_mixed_phi_kernel_instruction.h b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_mixed_phi_kernel_instruction.h new file mode 100644 index 00000000000000..d39e5fa9d1fea0 --- /dev/null +++ b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_mixed_phi_kernel_instruction.h @@ -0,0 +1,42 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/new_executor/instruction/onednn/onednn_phi_kernel_instruction.h" + +namespace pir { +class Operation; +} // namespace pir + +namespace paddle { +namespace framework { +class Scope; +class ValueExecutionInfo; + +using RuntimeAttribute = phi::Attribute; +using PIRAttribute = pir::Attribute; + +class OneDNNMixedPhiKernelInstruction : public OneDNNPhiKernelInstruction { + public: + OneDNNMixedPhiKernelInstruction(size_t id, + const platform::Place& place, + ::pir::Operation* op, + const ValueExecutionInfo* value_exec_info); + + void Run() override; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/instruction/onednn/onednn_phi_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_phi_kernel_instruction.cc new file mode 100644 index 00000000000000..71385619cb958b --- /dev/null +++ b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_phi_kernel_instruction.cc @@ -0,0 +1,388 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/new_executor/instruction/onednn/onednn_phi_kernel_instruction.h" + +#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" +#include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h" +#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" +#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/core/meta_tensor.h" +#include "paddle/phi/core/type_defs.h" + +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/value.h" + +#include "dnnl.hpp" // NOLINT +#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h" +#include "paddle/fluid/framework/type_defs.h" +#include "paddle/fluid/ir_adaptor/translator/op_compat_info.h" +#include "paddle/phi/backends/onednn/onednn_context.h" +#include "paddle/phi/backends/onednn/onednn_helper.h" +#include "paddle/phi/kernels/funcs/data_layout_transform.h" + +namespace paddle { +namespace framework { + +static RuntimeAttribute ConvertPirAttribute2RuntimeAttribute( + PIRAttribute attr, + const std::string& attr_name, + const paddle::dialect::OpYamlInfoParser& op_yaml_info) { + auto& attr_type_name = op_yaml_info.AttrTypeName(attr_name); + if (attr_type_name == "pir::Int32Attribute") { + return attr.dyn_cast().data(); + } else if (attr_type_name == "pir::FloatAttribute") { + return attr.dyn_cast().data(); + } else if (attr_type_name == "pir::BoolAttribute") { + return attr.dyn_cast().data(); + } else if (attr_type_name == "pir::StrAttribute") { + return attr.dyn_cast().AsString(); + } else if (attr_type_name == "pir::ArrayAttribute") { + auto array_list = attr.dyn_cast().AsVector(); + std::vector vec_res; + if (array_list.size() > 0) { + PADDLE_ENFORCE_EQ(array_list[0].isa(), + true, + phi::errors::Unimplemented( + "the 0th elementwise MUST be pir::Int32Attribute")); + for (size_t i = 0; i < array_list.size(); ++i) { + vec_res.push_back(array_list[i].dyn_cast().data()); + } + } + return vec_res; + } else if (attr_type_name == "pir::ArrayAttribute") { + auto array_list = attr.dyn_cast().AsVector(); + std::vector vec_res; + if (array_list.size() > 0) { + if (array_list[0].isa()) { + for (size_t i = 0; i < array_list.size(); ++i) { + vec_res.push_back( + array_list[i].dyn_cast().data()); + } + + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "ConvertPirAttribute2RuntimeAttribute not support [%s] ", + attr_type_name)); + } + } + return vec_res; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "ConvertPirAttribute2RuntimeAttribute not support [%s] ", + attr_type_name)); + } +} + +void TensorNameMap(pir::Operation* op, + const ValueExecutionInfo& value_exec_info, + const paddle::dialect::OpYamlInfoParser& op_yaml_info, + std::map>& + inputs_tensor_name_map, // NOLINT + std::map>& + outputs_tensor_name_map) { // NOLINT + const Scope* inner_scope = value_exec_info.GetScope(); + VLOG(6) << "TensorNameMap in scope[" << inner_scope << "]"; + + auto& vec_kernel_fn_tensor_params = op_yaml_info.TensorParams(true); + + auto& name2id = op_yaml_info.InputName2Id(); + + std::string fluid_op_name = op_yaml_info.GetOriginOpName(); + + auto& op_normalizer = paddle::translator::OpNameNormalizer::instance(); + + for (auto& name : vec_kernel_fn_tensor_params) { + PADDLE_ENFORCE_EQ( + name2id.count(name), + true, + phi::errors::NotFound("param [%s] MUST in name2id map", name)); + auto index = name2id.at(name); + pir::Value ptr = op->operand_source(index); + + if (!IsInvalid(ptr)) { + continue; + } + + auto legacy_arg_name = op_normalizer.GetLegacyArgName(fluid_op_name, name); + auto in_var_name = value_exec_info.GetVarName(ptr); + PADDLE_ENFORCE_NOT_NULL(inner_scope->FindVar(in_var_name), + phi::errors::PreconditionNotMet( + "can not find var[%s] in scope", in_var_name)); + + auto type = ptr.type(); + if (type.isa() || + type.isa()) { + inputs_tensor_name_map[legacy_arg_name] = {in_var_name}; + } else if (type.isa()) { + auto var = inner_scope->FindVar(in_var_name); + auto var_ref = var->Get(); + std::vector vec_tmp; + vec_tmp.reserve(var_ref.size()); + for (size_t k = 0; k < var_ref.size(); ++k) { + vec_tmp.push_back(value_exec_info.GetVarName(var_ref[k])); + } + inputs_tensor_name_map[legacy_arg_name] = vec_tmp; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "only support AllocatedDenseTensor, AllocatedSelectedRowsType and " + "pir::vector type")); + } + } + + auto& output_name_list = op_yaml_info.OutputNames(); + for (size_t i = 0; i < output_name_list.size(); ++i) { + auto name = output_name_list[i]; + pir::Value ptr = op->result(i); + auto legacy_arg_name = op_normalizer.GetLegacyArgName(fluid_op_name, name); + + if (!IsInvalid(ptr)) { + continue; + } + + auto out_var_name = value_exec_info.GetVarName(ptr); + + PADDLE_ENFORCE_NOT_NULL(inner_scope->FindVar(out_var_name), + phi::errors::PreconditionNotMet( + "can not find var[%s] in scope", out_var_name)); + + auto type = ptr.type(); + if (type.isa() || + type.isa()) { + outputs_tensor_name_map[legacy_arg_name] = {out_var_name}; + } else if (type.isa()) { + auto var = inner_scope->FindVar(out_var_name); + auto var_ref = var->Get(); + std::vector vec_tmp; + vec_tmp.reserve(var_ref.size()); + for (size_t k = 0; k < var_ref.size(); ++k) { + vec_tmp.push_back(value_exec_info.GetVarName(var_ref[k])); + } + outputs_tensor_name_map[legacy_arg_name] = vec_tmp; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "only support AllocatedDenseTensor, AllocatedSelectedRowsType and " + "pir::vector type")); + } + } +} + +OneDNNPhiKernelInstruction::OneDNNPhiKernelInstruction( + size_t id, + const platform::Place& place, + pir::Operation* op, + const ValueExecutionInfo* value_exec_info) + : InstructionBase(id, place), value_exec_info_(value_exec_info) { + // Step1: build phi kernel instruction as PhiKernelInstruction + auto op_attributes = op->attributes(); + auto op_name = + op_attributes.at("op_name").dyn_cast().AsString(); + pir::OpInfo op_info = + pir::IrContext::Instance()->GetRegisteredOpInfo(op_name); + op_ = op; + phi_op_name_ = op_name; + VLOG(6) << "construct phi kernel instruction for: " << phi_op_name_; + + SetKernelType(AnalyseOpFuncType(op, place)); + VLOG(6) << "finish process analyse kernel type"; + + infer_meta_interface_ = + op_info.GetInterfaceImpl(); + VLOG(6) << "finish process infer_meta_interface_"; + + auto yaml_interface = + op_info.GetInterfaceImpl(); + PADDLE_ENFORCE_NOT_NULL( + yaml_interface, + phi::errors::PreconditionNotMet( + "can not find OpYamlInfoInterface from [%s]", phi_op_name_)); + paddle::dialect::OpYamlInfoParser yaml_info_parser( + yaml_interface->get_op_info_(), + paddle::dialect::IsOneDNNLegacyOp(op_name)); + VLOG(6) << "finish process yaml_info_parser"; + + if (infer_meta_interface_) { + BuildPhiContext< + phi::InferMetaContext, + phi::MetaTensor, + phi::MetaTensor, + paddle::small_vector, + paddle::small_vector, + false>(op, *value_exec_info_, yaml_info_parser, &infer_meta_context_); + } + VLOG(6) << "finish process infer meta context"; + + auto kernel_name = + op_attributes.at("kernel_name").dyn_cast().AsString(); + auto kernel_key = op_attributes.at("kernel_key") + .dyn_cast() + .data(); + + phi_kernel_ = new phi::Kernel( + phi::KernelFactory::Instance().SelectKernel(kernel_name, kernel_key)); + PADDLE_ENFORCE_EQ( + phi_kernel_->IsValid(), true, "not found kernel for [%s]", kernel_name); + VLOG(6) << "finish process select kernel"; + + BuildPhiContext, + paddle::small_vector, + true>( + op, *value_exec_info_, yaml_info_parser, &kernel_context_); + + kernel_context_.SetDeviceContext(phi::DeviceContextPool::Instance().Get( + phi::TransToPhiPlace(kernel_key.backend()))); + VLOG(6) << "finish process kernel context"; + + SetDeviceContext( + ParseDeviceContext(op, + phi::DeviceContextPool::Instance().Get( + phi::TransToPhiPlace(kernel_key.backend())), + place, + GetExecutionStream(), + GetStreamPriority())); + VLOG(6) << "finish process device context"; + + InitInputsOutputsIds(op, *value_exec_info); + VLOG(6) << "finish process inputs outputs index"; + + auto& no_need_buffer_ids = yaml_info_parser.NoNeedBufferIds(); + std::unordered_set no_need_buffer_values; + for (size_t id = 0; id < no_need_buffer_ids.size(); id++) { + no_need_buffer_values.insert(op->operand_source(no_need_buffer_ids[id])); + } + SetNoNeedBuffer(no_need_buffer_values); + VLOG(6) << "finish process no need buffer"; + + // Step2: build layout_transform information + if (op_attributes.count("layout_transform_arg")) { + auto layout_transform_arg = op_attributes.at("layout_transform_arg") + .dyn_cast() + .AsString(); + auto data_layout = op_attributes.at(layout_transform_arg) + .dyn_cast() + .AsString(); + input_layout_ = common::StringToDataLayout(data_layout); + std::vector layout_transform_inputs_attr = + op->attributes() + .at("layout_transform_inputs") + .dyn_cast() + .AsVector(); + std::vector layout_transform_inputs; + for (auto& attr : layout_transform_inputs_attr) { + auto pair = kernel_context_.InputRangeAt(value_exec_info_->GetIdByName( + attr.dyn_cast().AsString())); + for (int i = pair.first; i < pair.second; ++i) { + layout_transform_inputs_.insert(i); + } + } + } + + // Step3: build extra attr information + if (op_attributes.count("extra_args")) { + std::vector extra_args_attr = + op->attributes() + .at("extra_args") + .dyn_cast() + .AsVector(); + std::vector extra_args; + for (auto& attr : extra_args_attr) { + auto attr_name = attr.dyn_cast().AsString(); + extra_attr_[attr_name] = ConvertPirAttribute2RuntimeAttribute( + op_attributes.at(attr_name), attr_name, yaml_info_parser); + } + } + TensorNameMap(op, *value_exec_info_, yaml_info_parser, inputs_, outputs_); +} + +OneDNNPhiKernelInstruction::~OneDNNPhiKernelInstruction() { + if (phi_kernel_ != nullptr) { + delete phi_kernel_; + } +} + +void OneDNNPhiKernelInstruction::Run() { + // Step1. TransLayout + auto inputs = kernel_context_.InputsBetween( + size_t(0), kernel_context_.InputsSize()); + for (size_t i = 0; i < inputs.size(); ++i) { + auto input = inputs[i]; + if (input->layout() != phi::DataLayout::ONEDNN) { + phi::DataLayout from_layout = input->layout(); + + // Handle 'layout_transform' in + // ops_onednn_extra.yaml(GetKernelTypeForVar) + if (layout_transform_inputs_.count(i) && + input_layout_ != phi::DataLayout::kAnyLayout) { + from_layout = input_layout_; + } + + auto transed_tensor = const_cast(input); + + if (from_layout == DataLayout::kNHWC || + from_layout == DataLayout::kNDHWC) { + phi::funcs::MatchShapeToLayout( + transed_tensor, from_layout, phi::DataLayout::ONEDNN); + // We register only NHWC assuming that model is consistent e.g. either + // NHWC or NCHW + phi::OneDNNContext::tls().set_cur_paddle_data_layout(from_layout); + } + + if (from_layout == DataLayout::kAnyLayout) { + from_layout = phi::OneDNNContext::tls().get_cur_paddle_data_layout(); + } + + dnnl::memory::desc out_mem_desc = + phi::funcs::make_memory_desc(*input, from_layout); + transed_tensor->set_mem_desc(out_mem_desc); + } + } + + // Step2. Append extra information into ctx + // SetDnnAttrIntoDeviceContext + // SetInputsName SetOutputsName + auto one_dnn_ctx = const_cast( + &kernel_context_.GetDeviceContext()); + for (auto& attr : extra_attr_) { + one_dnn_ctx->SetDnnAttr(attr.first, attr.second); + } + one_dnn_ctx->SetInputsName(inputs_); + one_dnn_ctx->SetOutputsName(outputs_); + + // Step3. InferMeta + if (infer_meta_interface_) { + infer_meta_interface_->infer_meta_(&(infer_meta_context_)); + } + + // Step4. Run kernel + VLOG(6) << "Run op " << phi_op_name_ << " infer meta."; + (*(phi_kernel_))(&(kernel_context_)); + VLOG(6) << "Run op " << phi_op_name_ << " kernel."; + + // Step5. ClearDnnAttr + one_dnn_ctx->ClearDnnAttr(); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/instruction/onednn/onednn_phi_kernel_instruction.h b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_phi_kernel_instruction.h new file mode 100644 index 00000000000000..c15a69728f9c3d --- /dev/null +++ b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_phi_kernel_instruction.h @@ -0,0 +1,82 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h" + +namespace pir { +class Operation; +} // namespace pir + +namespace paddle { +namespace framework { +class Scope; +class ValueExecutionInfo; + +using RuntimeAttribute = phi::Attribute; +using PIRAttribute = pir::Attribute; + +class OneDNNPhiKernelInstruction : public InstructionBase { + public: + OneDNNPhiKernelInstruction(size_t id, + const platform::Place& place, + ::pir::Operation* op, + const ValueExecutionInfo* value_exec_info); + + ~OneDNNPhiKernelInstruction(); + + phi::Kernel* PhiKernel() const { return phi_kernel_; } + + const phi::KernelContext& KernelContext() const { return kernel_context_; } + + const phi::InferMetaContext& InferMetaContext() const { + return infer_meta_context_; + } + + paddle::dialect::InferMetaInterface::Concept* InferMetaInterface() const { + return infer_meta_interface_; + } + + ::pir::Operation* Operation() const override { return op_; } + + void Run() override; + + const std::string& Name() const override { return phi_op_name_; } + + private: + paddle::dialect::InferMetaInterface::Concept* infer_meta_interface_{ + nullptr}; // not owned + + phi::InferMetaContext infer_meta_context_; + + phi::KernelContext kernel_context_; + + phi::Kernel* phi_kernel_{nullptr}; // not owned + + std::string phi_op_name_; + + ::pir::Operation* op_{nullptr}; // not owned + + const ValueExecutionInfo* value_exec_info_; // not owned + + std::set layout_transform_inputs_{}; + phi::DataLayout input_layout_{phi::DataLayout::kAnyLayout}; + std::map extra_attr_{}; + std::map> inputs_{}; + std::map> outputs_{}; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/pir_interpreter.cc b/paddle/fluid/framework/new_executor/pir_interpreter.cc index 7dbb514513fc24..1cd1117d0ea1d2 100644 --- a/paddle/fluid/framework/new_executor/pir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/pir_interpreter.cc @@ -34,6 +34,9 @@ #include "paddle/phi/core/sparse_csr_tensor.h" #ifdef PADDLE_WITH_DNNL +#include "paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_kernel_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/onednn/onednn_mixed_phi_kernel_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/onednn/onednn_phi_kernel_instruction.h" #include "paddle/fluid/platform/mkldnn_helper.h" #endif @@ -728,6 +731,22 @@ void PirInterpreter::BuildInstruction() { } else { CREATE_INSTR(PhiKernelInstruction); } +#ifdef PADDLE_WITH_DNNL + } else if (op.dialect()->name() == "pd_onednn_kernel") { + auto op_name = op.attributes() + .at("op_name") + .dyn_cast<::pir::StrAttribute>() + .AsString(); + VLOG(6) << "process " << op_name; + + if (op.isa()) { + CREATE_INSTR(OneDNNPhiKernelInstruction); + } else if (op.isa()) { + CREATE_INSTR(OneDNNMixedPhiKernelInstruction); + } else { + CREATE_INSTR(OneDNNLegacyKernelInstruction); + } +#endif #ifdef PADDLE_WITH_CINN } else if (op.dialect()->name() == "cinn_runtime") { CREATE_INSTR(CinnJitInstruction); diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 76a787cda64bf4..626073d143e3e3 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -44,6 +44,9 @@ #include "paddle/pir/core/operation.h" #include "paddle/pir/core/value.h" +#ifdef PADDLE_WITH_DNNL +#include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.h" +#endif // NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in // paddle/fluid/pir/dialect/CMakeLists.txt. #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" @@ -77,7 +80,10 @@ using AttributeHandlerFn = std::function; using DenseTensorTypeStorage = paddle::dialect::DenseTensorTypeStorage; constexpr char kTargetDialectPrefix[] = "pd_op."; // NOLINT -constexpr char kEmptyVarName[] = "@EMPTY@"; // NOLINT +#ifdef PADDLE_WITH_DNNL +constexpr char kOneDNNTargetDialectPrefix[] = "pd_onednn_op."; // NOLINT +#endif +constexpr char kEmptyVarName[] = "@EMPTY@"; // NOLINT static const std::unordered_set SpecialNonInplaceOps = {}; @@ -223,12 +229,36 @@ inline pir::Operation* InsertCreateArrayOp(pir::IrContext* ctx, return create_array_op.operation(); } +inline std::string GetPrefix(pir::IrContext* ctx, const OpDesc& op_desc) { +#ifdef PADDLE_WITH_DNNL + if (op_desc.GetAttrIfExists("use_mkldnn")) { + std::string target_op_name = + kOneDNNTargetDialectPrefix + OpNameCompatibleMapping(op_desc.Type()); + if (IsInplace(op_desc) && *target_op_name.rbegin() != '_') { + target_op_name += "_"; + } + auto op_info = ctx->GetRegisteredOpInfo(target_op_name); + if (!op_info) { + VLOG(3) << op_desc.Type() + << "'s use_mkldnn == True, but PIR not support OneDNN for this " + "op right now."; + return kTargetDialectPrefix; + } else { + return kOneDNNTargetDialectPrefix; + } + } else { + return kTargetDialectPrefix; + } +#else + return kTargetDialectPrefix; +#endif +} } // namespace pir::OpInfo OpTranscriber::LoopkUpOpInfo(pir::IrContext* ctx, const OpDesc& op_desc) { std::string target_op_name = - kTargetDialectPrefix + OpNameCompatibleMapping(op_desc.Type()); + GetPrefix(ctx, op_desc) + OpNameCompatibleMapping(op_desc.Type()); if (IsInplace(op_desc) && *target_op_name.rbegin() != '_') { target_op_name += "_"; } @@ -321,7 +351,7 @@ pir::OpInfo OpTranscriber::LoopkUpOpInfo(pir::IrContext* ctx, op_desc.Type(), target_op_name); - target_op_name = kTargetDialectPrefix + target_op_name; + target_op_name = GetPrefix(ctx, op_desc) + target_op_name; if (IsInplace(op_desc) && *target_op_name.rbegin() != '_') { target_op_name += "_"; } @@ -1054,7 +1084,7 @@ struct EmbeddingGradOpTranscriber : public OpTranscriber { pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx, const OpDesc& op_desc) override { std::string target_op_name = - kTargetDialectPrefix + OpNameCompatibleMapping(op_desc.Type()); + GetPrefix(ctx, op_desc) + OpNameCompatibleMapping(op_desc.Type()); bool is_sparse = paddle::get(op_desc.GetAttr("is_sparse")); @@ -1307,7 +1337,7 @@ struct AddNOpTranscriber : public OpTranscriber { pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx, const OpDesc& op_desc) override { std::string target_op_name = - kTargetDialectPrefix + OpNameCompatibleMapping(op_desc.Type()); + GetPrefix(ctx, op_desc) + OpNameCompatibleMapping(op_desc.Type()); if (IsInplace(op_desc)) { target_op_name += "_"; } else { diff --git a/paddle/fluid/ir_adaptor/translator/translate.cc b/paddle/fluid/ir_adaptor/translator/translate.cc index 7a7081fe1acbf2..04ddf1d13a5a8a 100644 --- a/paddle/fluid/ir_adaptor/translator/translate.cc +++ b/paddle/fluid/ir_adaptor/translator/translate.cc @@ -22,6 +22,9 @@ #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/core/program.h" +#ifdef PADDLE_WITH_DNNL +#include "paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h" +#endif namespace paddle { using LegacyProgramDesc = ::paddle::framework::ProgramDesc; @@ -31,6 +34,9 @@ std::unique_ptr TranslateLegacyProgramToProgram( const LegacyProgramDesc& legacy_program) { pir::IrContext* ctx = pir::IrContext::Instance(); ctx->GetOrRegisterDialect(); +#ifdef PADDLE_WITH_DNNL + ctx->GetOrRegisterDialect(); +#endif auto program = std::make_unique(ctx); translator::ProgramTranslator program_translator(&legacy_program, program.get()); diff --git a/paddle/fluid/ir_adaptor/translator/utils.cc b/paddle/fluid/ir_adaptor/translator/utils.cc index ebba4428220f70..dbd85292974bf0 100644 --- a/paddle/fluid/ir_adaptor/translator/utils.cc +++ b/paddle/fluid/ir_adaptor/translator/utils.cc @@ -23,6 +23,9 @@ #include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/builtin_type.h" #include "paddle/pir/core/utils.h" +#ifdef PADDLE_WITH_DNNL +#include "paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h" +#endif namespace paddle { namespace dialect { @@ -94,6 +97,9 @@ std::vector CheckUnregisteredOperationInBlock( std::vector CheckUnregisteredOperation( pir::IrContext* ctx, const framework::ProgramDesc& legacy_program) { ctx->GetOrRegisterDialect(); +#ifdef PADDLE_WITH_DNNL + ctx->GetOrRegisterDialect(); +#endif std::vector unregistered_ops; for (size_t block_idx = 0; block_idx < legacy_program.Size(); block_idx++) { diff --git a/paddle/fluid/pir/dialect/CMakeLists.txt b/paddle/fluid/pir/dialect/CMakeLists.txt index 2c812ccada69af..337841b2274971 100644 --- a/paddle/fluid/pir/dialect/CMakeLists.txt +++ b/paddle/fluid/pir/dialect/CMakeLists.txt @@ -27,6 +27,7 @@ set(pir_op_fwd_src_yaml ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/ops.yaml) set(pir_op_bwd_src_yaml ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml) + set(pir_update_op_fwd_src_yaml ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/update_ops.yaml) set(parsed_op_dir @@ -108,6 +109,44 @@ set(generated_files_pd_op "${pir_bwd_op_source_file}" "${pir_update_op_source_file}") +if(WITH_MKLDNN) + set(pir_op_onednn_yaml ${parsed_op_dir}/onednn.parsed.yaml) + + set(pd_onednn_op_yaml_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/onednn.yaml) + + set(pd_ops_onednn_extra_yaml_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml + ) + + set(op_onednn_info_file ${PD_DIALECT_SOURCE_DIR}/pd_onednn_op_info.cc) + set(op_onednn_info_file_tmp ${op_onednn_info_file}.tmp) + + set(onednn_op_namespace paddle,onednn,dialect) + set(onednn_dialect_name pd_onednn_op) + set(onednn_op_header_file ${PD_DIALECT_SOURCE_DIR}/pd_onednn_op.h) + set(onednn_op_source_file ${PD_DIALECT_SOURCE_DIR}/pd_onednn_op.cc) + set(onednn_op_header_file_tmp ${onednn_op_header_file}.tmp) + set(onednn_op_source_file_tmp ${onednn_op_source_file}.tmp) + + execute_process( + COMMAND ${PYTHON_EXECUTABLE} ${op_gen_parsed_yaml_file} --op_yaml_path + ${pd_onednn_op_yaml_file} --output_path ${pir_op_onednn_yaml}) + + execute_process( + COMMAND + ${PYTHON_EXECUTABLE} ${op_gen_file} --op_yaml_files ${op_yaml_files} + --op_compat_yaml_file ${op_compat_yaml_file} --namespaces + ${onednn_op_namespace} --dialect_name ${onednn_dialect_name} + --op_def_h_file ${onednn_op_header_file_tmp} --op_info_file + ${op_onednn_info_file_tmp} --op_def_cc_file ${onednn_op_source_file_tmp} + --onednn_yaml_file ${pir_op_onednn_yaml} --ops_onednn_extra_yaml_file + ${pd_ops_onednn_extra_yaml_file}) + + set(generated_files_onednn_pd_op + "${onednn_op_header_file}" "${onednn_op_source_file}" + "${op_onednn_info_file}") +endif() set(api_gen_yaml_files ${op_fwd_yaml},${op_bwd_yaml},${pir_op_fwd_yaml},${pir_op_bwd_yaml},${pir_update_op_fwd_yaml} ) @@ -159,8 +198,10 @@ execute_process( set(generated_files_ops_api "${ops_api_source_file}") -set(generated_files_pir ${generated_files_pd_op} ${generated_files_pd_api} - ${generated_files_python_c} ${generated_files_ops_api}) +set(generated_files_pir + ${generated_files_pd_op} ${generated_files_onednn_pd_op} + ${generated_files_pd_api} ${generated_files_python_c} + ${generated_files_ops_api}) foreach(generated_file ${generated_files_pir}) if(EXISTS "${generated_file}.tmp" AND EXISTS "${generated_file}") execute_process(COMMAND ${CMAKE_COMMAND} -E copy_if_different @@ -206,6 +247,10 @@ set(op_dialect_srcs ${pir_update_op_source_file} ${api_source_file}) +if(WITH_MKLDNN) + set(op_dialect_srcs ${op_dialect_srcs} ${onednn_op_source_file}) +endif() + set(op_dialect_deps phi common pir type_info string_helper) cc_library( @@ -222,6 +267,13 @@ set(op_dialect_vjp_srcs ${op_decomp_source_file} ${op_vjp_source_file} ${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/base/decomp_trans.cc) + +if(WITH_MKLDNN) + set(op_dialect_vjp_srcs + ${op_dialect_vjp_srcs} + ${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/op_onednn_dialect.cc) +endif() + set(op_dialect_vjp_deps primitive_vjp_experimental op_dialect) cc_library( diff --git a/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.cc b/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.cc index 95e77ff6169c68..ecf04d4411397b 100644 --- a/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.cc +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.cc @@ -122,7 +122,110 @@ void KernelDialect::PrintOperation(pir::Operation *op, } } +#ifdef PADDLE_WITH_DNNL +OneDNNKernelDialect::OneDNNKernelDialect(pir::IrContext *context) + : pir::Dialect(name(), context, pir::TypeId::get()) { + initialize(); +} + +void OneDNNKernelDialect::initialize() { + RegisterTypes(); + RegisterOps(); + RegisterAttributes(); +} + +void OneDNNKernelDialect::PrintType(pir::Type type, std::ostream &os) const { + if (type.isa()) { + AllocatedDenseTensorType tensor_type = + type.dyn_cast(); + + os << phi::AllocationTypeStr(tensor_type.place().GetType()) << "_"; + os << "tensor<"; + for (auto d : common::vectorize(tensor_type.dims())) { + os << d; + os << "x"; + } + tensor_type.dtype().Print(os); + os << ">"; + } else if (type.isa()) { + AllocatedSelectedRowsType tensor_type = + type.dyn_cast(); + + os << phi::AllocationTypeStr(tensor_type.place().GetType()) << "_"; + os << "tensor<"; + for (auto d : common::vectorize(tensor_type.dims())) { + os << d; + os << "x"; + } + tensor_type.dtype().Print(os); + os << ">"; + } else if (type.isa()) { + AllocatedDenseTensorArrayType tensor_array_type = + type.dyn_cast(); + + os << phi::AllocationTypeStr(tensor_array_type.place().GetType()) << "_"; + os << "tensor_array<"; + tensor_array_type.dtype().Print(os); + os << ">"; + } +} + +void OneDNNKernelDialect::PrintAttribute(pir::Attribute attr, + std::ostream &os) const { + phi::KernelKey kernel = attr.dyn_cast().data(); + + os << ""; +} + +void OneDNNKernelDialect::PrintOperation(pir::Operation *op, + pir::IrPrinter &printer) const { + if (op->dyn_cast() || op->dyn_cast()) { + auto &os = printer.os; + printer.PrintOpResult(op); + os << " ="; + if (auto phi_kernel_op = op->dyn_cast()) { + std::string kernel_name = phi_kernel_op.kernel_name(); + if (op->attributes().count("is_inplace") != 0 && + op->attributes() + .at("is_inplace") + .dyn_cast() + .data()) { + kernel_name = kernel_name + "_"; + } + os << " \"" << kernel_name << "(phi_kernel)\""; + } else { + auto legacy_kernel_op = op->dyn_cast(); + std::string kernel_name = legacy_kernel_op.kernel_name(); + if (op->attributes().count("is_inplace") != 0 && + op->attributes() + .at("is_inplace") + .dyn_cast() + .data()) { + kernel_name = kernel_name + "_"; + } + os << " \"" << kernel_name << "(legacy_kernel)\""; + } + printer.PrintOpOperands(op); + printer.PrintAttributeMap(op); + os << " :"; + printer.PrintOperandsType(op); + os << " -> "; + printer.PrintOpReturnType(op); + } else { + printer.PrintGeneralOperation(op); + } +} +#endif + } // namespace dialect } // namespace paddle IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::KernelDialect) +#ifdef PADDLE_WITH_DNNL +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNKernelDialect) +#endif diff --git a/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h b/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h index d2fbcadaf8cf2a..fbdb53a40b183d 100644 --- a/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h @@ -36,7 +36,29 @@ class KernelDialect : public pir::Dialect { void initialize(); }; +#ifdef PADDLE_WITH_DNNL +class OneDNNKernelDialect : public pir::Dialect { + public: + explicit OneDNNKernelDialect(pir::IrContext* context); + + static const char* name() { return "pd_onednn_kernel"; } + + void PrintType(pir::Type type, std::ostream& os) const override; + + void PrintAttribute(pir::Attribute attr, std::ostream& os) const override; + + void PrintOperation(pir::Operation* op, + pir::IrPrinter& printer) const override; // NOLINT + + private: + void initialize(); +}; +#endif + } // namespace dialect } // namespace paddle IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::KernelDialect) +#ifdef PADDLE_WITH_DNNL +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNKernelDialect) +#endif diff --git a/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc index 8ad46bc8906adb..45f0a848fc174d 100644 --- a/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc @@ -98,8 +98,135 @@ phi::KernelKey LegacyKernelOp::kernel_key() { return attributes().at("kernel_key").dyn_cast().data(); } +#ifdef PADDLE_WITH_DNNL +const char* OneDNNPhiKernelOp::attributes_name[attributes_num] = { // NOLINT + "op_name", + "kernel_name", + "kernel_key"}; + +void OneDNNPhiKernelOp::VerifySig() { + VLOG(4) << "Verifying inputs, outputs and attributes for: OneDNNPhiKernelOp."; + + auto& attributes = this->attributes(); + + PADDLE_ENFORCE(attributes.count("op_name") > 0 && + attributes.at("op_name").isa(), + phi::errors::PreconditionNotMet( + "Type of attribute: op_name is not right.")); + + PADDLE_ENFORCE(attributes.count("kernel_name") > 0 && + attributes.at("kernel_name").isa(), + phi::errors::PreconditionNotMet( + "Type of attribute: kernel_name is not right.")); + + PADDLE_ENFORCE(attributes.count("kernel_key") > 0 && + attributes.at("kernel_key").isa(), + phi::errors::PreconditionNotMet( + "Type of attribute: kernel_key is not right.")); +} + +std::string OneDNNPhiKernelOp::op_name() { + return attributes().at("op_name").dyn_cast().AsString(); +} +std::string OneDNNPhiKernelOp::kernel_name() { + return attributes() + .at("kernel_name") + .dyn_cast() + .AsString(); +} +phi::KernelKey OneDNNPhiKernelOp::kernel_key() { + return attributes().at("kernel_key").dyn_cast().data(); +} + +const char* OneDNNMixedPhiKernelOp::attributes_name[attributes_num] = + { // NOLINT + "op_name", + "kernel_name", + "kernel_key"}; + +void OneDNNMixedPhiKernelOp::VerifySig() { + VLOG(4) << "Verifying inputs, outputs and attributes for: " + "OneDNNMixedPhiKernelOp."; + + auto& attributes = this->attributes(); + + PADDLE_ENFORCE(attributes.count("op_name") > 0 && + attributes.at("op_name").isa(), + phi::errors::PreconditionNotMet( + "Type of attribute: op_name is not right.")); + + PADDLE_ENFORCE(attributes.count("kernel_name") > 0 && + attributes.at("kernel_name").isa(), + phi::errors::PreconditionNotMet( + "Type of attribute: kernel_name is not right.")); + + PADDLE_ENFORCE(attributes.count("kernel_key") > 0 && + attributes.at("kernel_key").isa(), + phi::errors::PreconditionNotMet( + "Type of attribute: kernel_key is not right.")); +} + +std::string OneDNNMixedPhiKernelOp::op_name() { + return attributes().at("op_name").dyn_cast().AsString(); +} +std::string OneDNNMixedPhiKernelOp::kernel_name() { + return attributes() + .at("kernel_name") + .dyn_cast() + .AsString(); +} +phi::KernelKey OneDNNMixedPhiKernelOp::kernel_key() { + return attributes().at("kernel_key").dyn_cast().data(); +} + +const char* OneDNNLegacyKernelOp::attributes_name[attributes_num] = { // NOLINT + "op_name", + "kernel_name", + "kernel_key"}; + +void OneDNNLegacyKernelOp::VerifySig() { + VLOG(4) + << "Verifying inputs, outputs and attributes for: OneDNNLegacyKernelOp."; + + auto& attributes = this->attributes(); + + PADDLE_ENFORCE(attributes.count("op_name") > 0 && + attributes.at("op_name").isa(), + phi::errors::PreconditionNotMet( + "Type of attribute: op_name is not right.")); + + PADDLE_ENFORCE(attributes.count("kernel_name") > 0 && + attributes.at("kernel_name").isa(), + phi::errors::PreconditionNotMet( + "Type of attribute: kernel_name is not right.")); + + PADDLE_ENFORCE(attributes.count("kernel_key") > 0 && + attributes.at("kernel_key").isa(), + phi::errors::PreconditionNotMet( + "Type of attribute: kernel_key is not right.")); +} + +std::string OneDNNLegacyKernelOp::op_name() { + return attributes().at("op_name").dyn_cast().AsString(); +} +std::string OneDNNLegacyKernelOp::kernel_name() { + return attributes() + .at("kernel_name") + .dyn_cast() + .AsString(); +} +phi::KernelKey OneDNNLegacyKernelOp::kernel_key() { + return attributes().at("kernel_key").dyn_cast().data(); +} +#endif + } // namespace dialect } // namespace paddle IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::PhiKernelOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::LegacyKernelOp) +#ifdef PADDLE_WITH_DNNL +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNPhiKernelOp) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNMixedPhiKernelOp) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNLegacyKernelOp) +#endif diff --git a/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h index a96aa5732d5806..df723158702085 100644 --- a/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h @@ -44,8 +44,51 @@ class LegacyKernelOp : public pir::Op { void VerifySig(); }; +#ifdef PADDLE_WITH_DNNL +class OneDNNPhiKernelOp : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd_onednn_kernel.phi_kernel"; } + static constexpr uint32_t attributes_num = 3; + static const char *attributes_name[attributes_num]; + std::string op_name(); + std::string kernel_name(); + phi::KernelKey kernel_key(); + void VerifySig(); +}; + +class OneDNNMixedPhiKernelOp : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd_onednn_kernel.phi_mixed_kernel"; } + static constexpr uint32_t attributes_num = 3; + static const char *attributes_name[attributes_num]; + std::string op_name(); + std::string kernel_name(); + phi::KernelKey kernel_key(); + void VerifySig(); +}; + +class OneDNNLegacyKernelOp : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd_onednn_kernel.legacy_kernel"; } + static constexpr uint32_t attributes_num = 3; + static const char *attributes_name[attributes_num]; + std::string op_name(); + std::string kernel_name(); + phi::KernelKey kernel_key(); + void VerifySig(); +}; +#endif + } // namespace dialect } // namespace paddle IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::PhiKernelOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::LegacyKernelOp) +#ifdef PADDLE_WITH_DNNL +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNPhiKernelOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNMixedPhiKernelOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNLegacyKernelOp) +#endif diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index 8e56406583385c..f5c2a8c43e775a 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -30,6 +30,7 @@ from op_kerneltype_gen import gen_kernel_type_for_var_str from op_member_func_gen import gen_op_get_inputs_outputs_str from op_verify_gen import gen_verify_func_str +from ops_onednn_extra_parser import parse_extra_args, parse_layout_transform from parse_kernel_key_gen import gen_parse_kernel_key_str from vjp_interface_black_list import vjp_interface_black_list @@ -63,6 +64,7 @@ #include "paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h" #include "paddle/fluid/pir/dialect/operator/interface/decomp.h" #include "paddle/fluid/pir/dialect/operator/trait/inplace.h" +#include "paddle/fluid/pir/dialect/operator/trait/onednn.h" #include "paddle/fluid/pir/dialect/operator/trait/custom_vjp.h" #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/phi/core/infermeta_utils.h" @@ -209,6 +211,17 @@ class {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{ return std::make_tuple(inputs, attributes, outputs, run_time_info, "{origin_op_name}"); }} """ + +OP_INFO_ONEDNN_TEMPLATE = """ +OpInfoTuple {op_name}::GetOpInfo() {{ + std::vector inputs = {{ {inputs} }}; + std::vector attributes = {{ {attributes} }}; + std::vector outputs = {{ {outputs} }}; + paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, "{kernel_func}", {{"{kernel_param}"}}, {{{kernel_key_dtype}}}, {{{kernel_key_backend}}}, {{{inplace}}}, {{{view}}}, {{{extra_args}}}, "{layout_transform_arg}", {{{layout_transform_inputs}}}, {is_onednn_only}, {dynamic_fallback}); + return std::make_tuple(inputs, attributes, outputs, run_time_info, "{origin_op_name}"); +}} +""" + CONSTRUCT_INPUT_INFO_TEMPLATE = """paddle::dialect::OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer}, {is_mutable_attribute}, {with_grad_semantic})""" CONSTRUCT_OUTPUT_INFO_TEMPLATE = """paddle::dialect::OpOutputInfo("{name}", "{typename}", {optional}, {intermediate})""" CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE = """paddle::dialect::OpAttributeInfo("{name}", "{typename}", "{data_type}")""" @@ -416,7 +429,7 @@ def __init__(self, op_yaml_item, op_compat_item): self.non_mutable_attribute_data_type_list, self.non_mutable_attribute_build_arg_type_list, self.non_mutable_attribute_default_value_list, - ) = self.parse_non_nutable_attribute() + ) = self.parse_non_mutable_attribute() # parse infermeta && kernel self.infer_meta_map = self.parse_infer_meta_map() @@ -458,6 +471,18 @@ def __init__(self, op_yaml_item, op_compat_item): # parse interfaces list self.interfaces_list = self.parse_op_interfaces() + # OneDNN info + if "extra_args" in self.op_yaml_item: + self.onednn_extra_args = self.op_yaml_item["extra_args"] + self.onednn_layout_transform = self.op_yaml_item["layout_transform"] + self.is_onednn_only = self.op_yaml_item["is_onednn_only"] + self.dynamic_fallback = self.op_yaml_item["dynamic_fallback"] + else: + self.onednn_extra_args = [] + self.onednn_layout_transform = None + self.is_onednn_only = False + self.dynamic_fallback = False + def parse_op_traits(self): if 'traits' in self.op_yaml_item: return self.op_yaml_item['traits'] @@ -629,7 +654,7 @@ def parse_mutable_attribute(self): sorted_mutable_attribute_type_list, ) - def parse_non_nutable_attribute(self): + def parse_non_mutable_attribute(self): op_non_mutable_attribute_name_list = [] op_non_mutable_attribute_type_list = [] op_non_mutable_attribute_data_type_list = [] @@ -1108,17 +1133,21 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): if ( op_info.backward_name and op_info.op_phi_name[0] not in vjp_interface_black_list + and dialect_name != "pd_onednn_op" ): op_interfaces += ["paddle::dialect::VjpInterface"] exclusive_interface_str = gen_exclusive_interface_str( op_info, op_info_items ) - if dialect_name == "pd_op": + if dialect_name == "pd_op" or dialect_name == "pd_onednn_op": op_interfaces += ["paddle::dialect::GetKernelTypeForVarInterface"] # if op has custom vjp rule, then append a CustomVjpTrait to it - if op_info.op_phi_name[0] in custom_vjp_op_name_list: + if ( + op_info.op_phi_name[0] in custom_vjp_op_name_list + and dialect_name != "pd_onednn_op" + ): op_traits += ["paddle::dialect::CustomVjpTrait"] # check op inputs and mutable_attributes grad semantics @@ -1139,6 +1168,15 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): if op_name[-1] == "_": op_traits += ["paddle::dialect::InplaceTrait"] + if dialect_name == "pd_onednn_op": + op_traits += ["paddle::dialect::OneDNNTrait"] + + if op_info.is_onednn_only: + op_traits += ["paddle::dialect::OneDNNOnlyTrait"] + + if op_info.dynamic_fallback: + op_traits += ["paddle::dialect::OneDNNDynamicFallbackTrait"] + op_traits_str = "" if len(op_traits) > 0: op_traits_str = "," + ",".join(op_traits) @@ -1154,6 +1192,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): if ( op_name in decomp_interface_declare_gen_op_list and kernel_func_name in decomp_interface_declare_gen_op_list + and dialect_name != "pd_onednn_op" ): op_interfaces = op_interfaces + [ "paddle::dialect::DecompInterface" @@ -1217,7 +1256,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): build_func_with_muta_attr_is_input = "" get_kernel_type_for_var_declare_str = "" - if dialect_name == "pd_op": + if dialect_name == "pd_op" or dialect_name == "pd_onednn_op": get_kernel_type_for_var_declare_str = ( get_kernel_type_for_var_declare_template ) @@ -1547,6 +1586,53 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): origin_op_name=op_info.op_yaml_item['name'], ) + if dialect_name == "pd_onednn_op": + if len(op_info.onednn_extra_args) > 0: + args_name = [] + for arg in op_info.onednn_extra_args: + args_name.append(arg["name"]) + + extra_args = '"' + '", "'.join(args_name) + '"' + else: + extra_args = "" + if op_info.onednn_layout_transform is None: + layout_transform_arg, layout_transform_inputs = ( + "", + "", + ) + else: + ( + layout_transform_arg, + layout_transform_inputs, + ) = op_info.onednn_layout_transform + layout_transform_inputs = ( + '"' + '", "'.join(layout_transform_inputs) + '"' + ) + + op_info_func_str = OP_INFO_ONEDNN_TEMPLATE.format( + op_name=op_class_name, + inputs=inputs_info_str, + attributes=attribute_info_str, + outputs=outputs_info_str, + infer_meta_func=infer_meta_func_str, + infer_meta_param=infer_meta_param_str, + kernel_func=kernel_func_str, + kernel_param=kernel_param_str, + kernel_key_dtype=kernel_key_dtype, + kernel_key_backend=kernel_key_backend, + inplace=inplace_str, + view=view_str, + origin_op_name=op_info.op_yaml_item['name'], + extra_args=extra_args, + layout_transform_arg=layout_transform_arg, + layout_transform_inputs=layout_transform_inputs, + is_onednn_only="true" + if op_info.is_onednn_only + else "false", + dynamic_fallback="true" + if op_info.dynamic_fallback + else "false", + ) # generate op verify function str op_verify_str = '' if not op_info.custom_verify: @@ -1591,7 +1677,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): # generate op GetKernelKeyForVar function str op_get_kernel_type_for_var_str = '' - if dialect_name == "pd_op": + if dialect_name == "pd_op" or dialect_name == "pd_onednn_op": op_get_kernel_type_for_var_str = ( gen_kernel_type_for_var_str( op_class_name, @@ -1620,6 +1706,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): op_info.backward_name and op_info.op_phi_name[0] not in vjp_interface_black_list + and dialect_name != "pd_onednn_op" ): op_vjp_str = gen_op_vjp_str( op_class_name, @@ -1650,7 +1737,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): ops_defined_list.append(infer_symbolic_shape_define_str) # NOTE(chenxi67)skip if dialect_name==cinn - if dialect_name == "cinn": + if dialect_name == "cinn" or dialect_name == "pd_onednn_op": pass else: ops_vjp_defined_list.append(op_vjp_str) @@ -1732,6 +1819,8 @@ def OpGenerator( op_info_file, op_def_cc_file, op_vjp_cc_file, + onednn_yaml_file, + ops_onednn_extra_yaml_file, ): # (1) Prepare: Delete existing old files: pd_op.h.tmp, pd_op.cc.tmp if os.path.exists(op_def_h_file): @@ -1745,8 +1834,32 @@ def OpGenerator( # (2) parse yaml files op_compat_parser = OpCompatParser(op_compat_yaml_file) + if dialect_name == "pd_onednn_op": + with open(ops_onednn_extra_yaml_file, "r") as f: + ops_onednn_extra = yaml.safe_load(f) + ops_onednn_extra_map = {} + for op in ops_onednn_extra: + op_name = op['op'] + item = {} + item["is_onednn_only"] = False + item["extra_args"] = parse_extra_args(op_name, op['extra_args']) + if 'layout_transform' in op: + item["layout_transform"] = parse_layout_transform( + op_name, op['layout_transform'] + ) + else: + item["layout_transform"] = None + if 'dynamic_fallback' in op: + item["dynamic_fallback"] = op['dynamic_fallback'] + else: + item["dynamic_fallback"] = False + item["attrs"] = parse_extra_args(op_name, op['extra_args']) + ops_onednn_extra_map[op_name] = item + op_yaml_files.insert(0, onednn_yaml_file) + op_infos = [] all_op_info_items = {} + first_file = True for yaml_file in op_yaml_files: op_yaml_items = [] with open(yaml_file, "r") as f: @@ -1756,7 +1869,7 @@ def OpGenerator( op_info_items = {} for op in op_yaml_items: op_compat_item = None - if dialect_name == "pd_op": + if dialect_name == "pd_op" or dialect_name == "pd_onednn_op": op_compat_item = op_compat_parser.get_compat(op['name']) if ( @@ -1782,11 +1895,26 @@ def OpGenerator( ) = op_compat_parser.parse_support_tensor(op) op_compat_item['scalar'] = scalar_item op_compat_item['int_array'] = int_array_item - - op_info_items[op['name']] = OpInfoParser(op, op_compat_item) - all_op_info_items[op['name']] = OpInfoParser(op, op_compat_item) + if dialect_name == "pd_onednn_op": + if first_file: + first_file = False + op["is_onednn_only"] = True + elif op['name'] in ops_onednn_extra_map: + onednn_item = ops_onednn_extra_map[op['name']] + op["is_onednn_only"] = onednn_item["is_onednn_only"] + op["extra_args"] = onednn_item["extra_args"] + op["layout_transform"] = onednn_item["layout_transform"] + op["dynamic_fallback"] = onednn_item["dynamic_fallback"] + op["attrs"] = op["attrs"] + onednn_item["attrs"] + else: + continue + item = OpInfoParser(op, op_compat_item) + op_info_items[op['name']] = item + all_op_info_items[op['name']] = item op_infos.append(op_info_items) + if dialect_name == "pd_onednn_op": + op_infos = [all_op_info_items] # (3) auto code gen op_list_strs = [] @@ -1858,14 +1986,15 @@ def OpGenerator( else: op_to_multi_kernels_map_str = "" - op_info_str = CC_OP_INFO_FILE_TEMPLATE.format( - op_declare=",".join(op_list_strs).replace("\n", ""), - op_to_multi_kernels_map=op_to_multi_kernels_map_str, - h_file=op_def_h_file[:-4], - ) + if op_info_file is not None: + op_info_str = CC_OP_INFO_FILE_TEMPLATE.format( + op_declare=",".join(op_list_strs).replace("\n", ""), + op_to_multi_kernels_map=op_to_multi_kernels_map_str, + h_file=op_def_h_file[:-4], + ) - with open(op_info_file, 'w') as f: - f.write(op_info_str) + with open(op_info_file, 'w') as f: + f.write(op_info_str) # (6) write to files for xx_op.cc.tmp for id in range(len(op_def_cc_file)): @@ -1874,8 +2003,17 @@ def OpGenerator( source_file_str = NAMESPACE_GARD_TEMPLATE.format( namespace=name, input=source_file_str ) # Add namespaces + + if dialect_name == "pd_onednn_op": + op_def_h_file_tmp = ( + "paddle/fluid/pir/dialect/operator/ir/pd_op.h\"\n#include \"" + + op_def_h_file + ) + else: + op_def_h_file_tmp = op_def_h_file + source_file_str = CC_FILE_TEMPLATE.format( - h_file=op_def_h_file[:-4], + h_file=op_def_h_file_tmp[:-4], input=source_file_str, define_type_id=define_type_id_strs[id], ) @@ -1887,7 +2025,11 @@ def OpGenerator( # and vjp is only avaible for pd dialect. vjp_source_file_str = "\n".join(vjp_source_file_strs) vjp_source_file_str = VJP_CC_FILE_TEMPLATE.format(input=vjp_source_file_str) - if dialect_name != 'cinn' and op_vjp_cc_file: + if ( + dialect_name != 'cinn' + and dialect_name != 'pd_onednn_op' + and op_vjp_cc_file + ): with open(op_vjp_cc_file, 'w') as f: f.write(vjp_source_file_str) @@ -1907,6 +2049,8 @@ def ParseArguments(): parser.add_argument('--op_info_file', type=str) parser.add_argument('--op_def_cc_file', type=str) parser.add_argument('--op_vjp_cc_file', type=str) + parser.add_argument('--onednn_yaml_file', type=str) + parser.add_argument('--ops_onednn_extra_yaml_file', type=str) return parser.parse_args() @@ -1926,6 +2070,8 @@ def ParseArguments(): op_info_file = args.op_info_file op_def_cc_files = args.op_def_cc_file.split(",") op_vjp_cc_file = args.op_vjp_cc_file + onednn_yaml_file = args.onednn_yaml_file + ops_onednn_extra_yaml_file = args.ops_onednn_extra_yaml_file # auto code generate OpGenerator( @@ -1937,4 +2083,6 @@ def ParseArguments(): op_info_file, op_def_cc_files, op_vjp_cc_file, + onednn_yaml_file, + ops_onednn_extra_yaml_file, ) diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index d379bedaab6437..9e861aa26ea8dd 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -103,6 +103,7 @@ 'sequence_mask', 'number_count', 'assign_value', + 'onednn_to_paddle_layout', ] NO_NEED_GEN_STATIC_ONLY_APIS = [ diff --git a/paddle/fluid/pir/dialect/op_generator/ops_onednn_extra_parser.py b/paddle/fluid/pir/dialect/op_generator/ops_onednn_extra_parser.py new file mode 100644 index 00000000000000..3296fa0d68829d --- /dev/null +++ b/paddle/fluid/pir/dialect/op_generator/ops_onednn_extra_parser.py @@ -0,0 +1,86 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Any, Dict, List, Tuple + + +def parse_plain_list(s: str, sep=",") -> List[str]: + if sep == ",": + patten = re.compile(r',(?![^{]*\})') # support "int[] a={1,2}" + items = re.split(patten, s.strip()) + items = [x.strip() for x in items] + return items + else: + return [item.strip() for item in s.strip().split(sep)] + + +def parse_arg(op_name: str, s: str) -> Dict[str, str]: + """parse an argument in following formats: + 1. typename name + 2. typename name = default_value + """ + typename, rest = (item.strip() for item in s.split(" ", 1)) + assert ( + len(typename) > 0 + ), f"The arg typename should not be empty. Please check the args of {op_name} in yaml." + + assert ( + rest.count("=") <= 1 + ), f"There is more than 1 = in an arg in {op_name}" + if rest.count("=") == 1: + name, default_value = (item.strip() for item in rest.split("=", 1)) + assert ( + len(name) > 0 + ), f"The arg name should not be empty. Please check the args of {op_name} in yaml." + assert ( + len(default_value) > 0 + ), f"The default value should not be empty. Please check the args of {op_name} in yaml." + return { + "typename": typename, + "name": name, + "default_value": default_value, + } + else: + name = rest.strip() + assert ( + len(name) > 0 + ), f"The arg name should not be empty. Please check the args of {op_name} in yaml." + return {"typename": typename, "name": name} + + +def parse_extra_args(op_name: str, arguments: str) -> List: + if arguments is None: + return [] + args_str = arguments.strip() + args = parse_plain_list(args_str) + + attrs = [] + + for arg in args: + item = parse_arg(op_name, arg) + typename = item["typename"] + name = item["name"] + attrs.append(item) + return attrs + + +def parse_layout_transform( + op_name: str, layout_transform: Dict[str, Any] +) -> Tuple[str, List]: + if layout_transform is None: + return "", [] + return layout_transform["arg_name"], parse_plain_list( + layout_transform["tensors"] + ) diff --git a/paddle/fluid/pir/dialect/operator/ir/onednn.yaml b/paddle/fluid/pir/dialect/operator/ir/onednn.yaml new file mode 100644 index 00000000000000..d7de4310d5781f --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/ir/onednn.yaml @@ -0,0 +1,9 @@ +- op : quantize + args : (Tensor input, bool is_negative_input=false, float scale=1.0, float shift=0.0, str output_format="NHWC", bool bfloat16=false) + output : Tensor(output) + infer_meta : + func : UnchangedInferMeta + param : [input] + kernel : + func : quantize + data_type : input diff --git a/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.cc new file mode 100644 index 00000000000000..0d65389cc4922b --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.cc @@ -0,0 +1,168 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/type_storage.h" +#include "paddle/fluid/pir/dialect/operator/transforms/param_to_variable.h" +#include "paddle/pir/core/builtin_type_interfaces.h" +#include "paddle/pir/core/interface_value.h" +#include "paddle/pir/core/ir_printer.h" +#include "paddle/pir/core/utils.h" +#include "paddle/pir/dialect/control_flow/ir/cf_dialect.h" +#include "paddle/pir/dialect/control_flow/ir/cf_op.h" + +#ifdef PADDLE_WITH_DNNL +#include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.h" +#endif + +namespace paddle { +namespace dialect { + +OneDNNOperatorDialect::OneDNNOperatorDialect(pir::IrContext *ctx) + : pir::Dialect(name(), ctx, pir::TypeId::get()) { + initialize(); +} + +void OneDNNOperatorDialect::initialize() { + // NOTE(zhangbo9674): GET_OP_LIST is defined in pd_op.h which is + // generated by op_gen.py, see details in + // paddle/fluid/pir/dialect/CMakeLists.txt. + // NOTE(Ruting)GET_MANUAL_OP_LIST is define in manual_op.h" + // use RegisterOps when list has more than two ops. + RegisterOps< +#define GET_OP_LIST +#include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op_info.cc" // NOLINT + >(); +} + +void OneDNNOperatorDialect::PrintType(pir::Type type, std::ostream &os) const { + os << type.dialect().name(); + os << '.'; + if (auto tensor_type = type.dyn_cast()) { + os << "tensor<"; + for (auto d : common::vectorize(tensor_type.dims())) { + os << d; + os << "x"; + } + tensor_type.dtype().Print(os); + os << ">"; + } else if (auto selected_rows_type = type.dyn_cast()) { + os << "selectedrows<"; + for (auto d : common::vectorize(selected_rows_type.dims())) { + os << d; + os << "x"; + } + selected_rows_type.dtype().Print(os); + os << ">"; + } else if (auto tensor_array_type = type.dyn_cast()) { + os << "tensor_array<"; + tensor_array_type.dtype().Print(os); + os << ">"; + } +} + +void OneDNNOperatorDialect::PrintAttribute(pir::Attribute attr, + std::ostream &os) const { + os << "(" << attr.dialect().name(); + os << '.'; + if (auto int_array_attr = attr.dyn_cast()) { + phi::IntArray data = int_array_attr.data(); + os << "IntArray)" + << "["; + const auto &inner_data = data.GetData(); + pir::PrintInterleave( + inner_data.begin(), + inner_data.end(), + [&os](int64_t i) { os << i; }, + [&os]() { os << ","; }); + os << "]"; + } else if (auto data_type_attr = attr.dyn_cast()) { + os << "DataType)" << data_type_attr.data(); + } else if (auto place_type_attr = attr.dyn_cast()) { + os << "Place)" << place_type_attr.data(); + } else if (auto data_layout_attr = attr.dyn_cast()) { + os << "DataLayout)" << data_layout_attr.data(); + } else { + os << "<#AttrNotImplemented>"; + } +} + +pir::Type OneDNNOperatorDialect::ParseType(pir::IrParser &parser) { // NOLINT + parser.ConsumeAToken("pd_op.tensor"); + parser.ConsumeAToken("<"); + std::vector dim{}; + Token dim_token = parser.PeekToken(); + while (dim_token.token_type_ == DIGIT) { + dim_token = parser.ConsumeToken(); + dim.push_back(atoi(dim_token.val_.c_str())); + std::string peek_token_val = parser.PeekToken().val_; + if (peek_token_val[0] != 'x') { + break; + } + parser.ConsumeToken(); + parser.lexer->Unget(static_cast(peek_token_val.size() - 1)); + if (parser.PeekToken().token_type_ != DIGIT) { + break; + } + } + phi::DDim ddim = common::make_ddim(dim); + pir::Type dtype = parser.ParseType(); + std::vector> lod; + std::vector lodv; + lodv.push_back(0); + lod.push_back(lodv); + parser.ConsumeAToken(">"); + return DenseTensorType::get( + parser.ctx, dtype, ddim, phi::DataLayout::UNDEFINED, lod, 0); +} + +pir::Attribute OneDNNOperatorDialect::ParseAttribute( + pir::IrParser &parser) { // NOLINT + std::string type_name = parser.ConsumeToken().val_; + std::string attribute_name = + type_name.substr(type_name.find('.') + 1, std::string::npos); + parser.ConsumeAToken(")"); + if (attribute_name == "IntArray") { + return IntArrayAttribute::Parse(parser); + } else if (attribute_name == "DataType") { + return DataTypeAttribute::Parse(parser); + } else if (attribute_name == "Place") { + return PlaceAttribute::Parse(parser); + } else if (attribute_name == "DataLayout") { + return DataLayoutAttribute::Parse(parser); + } else { + IR_THROW("No function to parse " + attribute_name + " exists!" + + parser.GetErrorLocationInfo()); + } +} + +void OneDNNOperatorDialect::PrintOperation(pir::Operation *op, + pir::IrPrinter &printer) const { + if (auto if_op = op->dyn_cast()) { + if_op.Print(printer); + } else if (auto while_op = op->dyn_cast()) { + while_op.Print(printer); + } else { + printer.PrintGeneralOperation(op); + } +} + +} // namespace dialect +} // namespace paddle + +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNOperatorDialect) diff --git a/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h b/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h new file mode 100644 index 00000000000000..ac6483d4d53ecb --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h @@ -0,0 +1,44 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/core/dialect.h" + +namespace paddle { +namespace dialect { + +class OneDNNOperatorDialect : public pir::Dialect { + public: + explicit OneDNNOperatorDialect(pir::IrContext* context); + + static const char* name() { return "pd_onednn_op"; } + + pir::Type ParseType(pir::IrParser& parser) override; // NOLINT + pir::Attribute ParseAttribute(pir::IrParser& parser) override; // NOLINT + + void PrintType(pir::Type type, std::ostream& os) const override; + void PrintAttribute(pir::Attribute type, std::ostream& os) const override; + + void PrintOperation(pir::Operation* op, + pir::IrPrinter& printer) const override; // NOLINT + + private: + void initialize(); +}; + +} // namespace dialect +} // namespace paddle + +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNOperatorDialect) diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index 57d7857a2498ce..0d571f8ef868a7 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -1464,6 +1464,15 @@ func: number_count data_type: numbers +- op: onednn_to_paddle_layout + args: (Tensor x, int dst_layout) + output: Tensor(out) + infer_meta: + func : UnchangedInferMeta + param : [x] + kernel: + func: onednn_to_paddle_layout + - op: sparse_momentum args: (Tensor param, Tensor grad, Tensor velocity, Tensor index, Tensor learning_rate, Tensor master_param,float mu, Scalar axis=0, bool use_nesterov=false,str regularization_method="", float regularization_coeff=0.0f, bool multi_precision=false, float rescale_grad=1.0f) output: Tensor(param_out), Tensor(velocity_out), Tensor(master_param_out) diff --git a/paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml b/paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml new file mode 100644 index 00000000000000..58897216793dd6 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml @@ -0,0 +1,33 @@ + +- op : conv2d + extra_args : bool is_test=false + layout_transform : + arg_name: data_format + tensors: input + +- op : conv2d_grad + extra_args : bool is_test=false + layout_transform : + arg_name: data_format + tensors: input, out_grad +# - op : matmul +# extra_args : str mkldnn_data_type="float32" +# layout_transform : +# arg_name: cur_paddle_data_layout +# tensors: x, y + +# - op : pad3d +# extra_args : +# layout_transform : +# arg_name: data_format +# tensors: x +# dynamic_fallback : True + +# - op : batch_norm +# extra_args : bool fuse_with_relu=false +# layout_transform : +# arg_name: data_layout +# tensors: x + +# - op : prelu +# extra_args : bool is_test=false, str mkldnn_data_type="float32" diff --git a/paddle/fluid/pir/dialect/operator/trait/onednn.h b/paddle/fluid/pir/dialect/operator/trait/onednn.h new file mode 100644 index 00000000000000..df810c6707df12 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/trait/onednn.h @@ -0,0 +1,49 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifdef PADDLE_WITH_DNNL + +#include "paddle/pir/core/op_base.h" + +namespace paddle { +namespace dialect { +class OneDNNTrait : public pir::OpTraitBase { + public: + explicit OneDNNTrait(pir::Operation *op) + : pir::OpTraitBase(op) {} +}; + +class OneDNNOnlyTrait : public pir::OpTraitBase { + public: + explicit OneDNNOnlyTrait(pir::Operation *op) + : pir::OpTraitBase(op) {} +}; + +class OneDNNDynamicFallbackTrait + : public pir::OpTraitBase { + public: + explicit OneDNNDynamicFallbackTrait(pir::Operation *op) + : pir::OpTraitBase(op) {} +}; + +} // namespace dialect +} // namespace paddle + +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNTrait) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNOnlyTrait) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNDynamicFallbackTrait) + +#endif diff --git a/paddle/fluid/pir/dialect/operator/trait/trait.cc b/paddle/fluid/pir/dialect/operator/trait/trait.cc index 2a5b7575959b9f..9d828570d389aa 100644 --- a/paddle/fluid/pir/dialect/operator/trait/trait.cc +++ b/paddle/fluid/pir/dialect/operator/trait/trait.cc @@ -14,6 +14,14 @@ #include "paddle/fluid/pir/dialect/operator/trait/custom_vjp.h" #include "paddle/fluid/pir/dialect/operator/trait/inplace.h" - +#ifdef PADDLE_WITH_DNNL +#include "paddle/fluid/pir/dialect/operator/trait/onednn.h" +#endif IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::InplaceTrait) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::CustomVjpTrait) + +#ifdef PADDLE_WITH_DNNL +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNTrait) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNOnlyTrait) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNDynamicFallbackTrait) +#endif diff --git a/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h b/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h index 637de470675eb1..662616bce773a0 100644 --- a/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h +++ b/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h @@ -93,6 +93,12 @@ struct OpRunTimeInfo { std::vector kernel_key_backend; std::vector> inplace; std::vector> view; + std::vector extra_args; + std::string layout_transform_arg; + std::vector layout_transform_inputs; + bool is_onednn_only; + bool dynamic_fallback; + OpRunTimeInfo(const std::string& infer_meta_func, const std::vector& infer_meta_param, const std::string& kernel_func, @@ -100,7 +106,12 @@ struct OpRunTimeInfo { const std::vector& dtype, const std::vector& backend, const std::vector>& inplace, - const std::vector>& view) + const std::vector>& view, + const std::vector& extra_args = {}, + const std::string& layout_transform_arg = "", + const std::vector& layout_transform_inputs = {}, + bool is_onednn_only = false, + bool dynamic_fallback = false) : infer_meta_func(infer_meta_func), infer_meta_param(infer_meta_param), kernel_func(kernel_func), @@ -108,7 +119,12 @@ struct OpRunTimeInfo { kernel_key_dtype(dtype), kernel_key_backend(backend), inplace(inplace), - view(view) {} + view(view), + extra_args(extra_args), + layout_transform_arg(layout_transform_arg), + layout_transform_inputs(layout_transform_inputs), + is_onednn_only(is_onednn_only), + dynamic_fallback(dynamic_fallback) {} }; } // namespace dialect diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc index 783ecbd567554c..c9c83d80bb8396 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.cc +++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc @@ -60,6 +60,7 @@ const std::unordered_set LegacyOpList = { SoftReluOp::name(), SoftReluGradOp::name()}; +const std::unordered_set OneDNNLegacyOpList = {}; enum class AttrType { UNDEFINED = 0, BOOL, @@ -220,6 +221,12 @@ VariantType GetAttributeData(const pir::Attribute& attr) { bool IsLegacyOp(const std::string& name) { return LegacyOpList.count(name); } +#ifdef PADDLE_WITH_DNNL +bool IsOneDNNLegacyOp(const std::string& name) { + return OneDNNLegacyOpList.count(name); +} +#endif + bool IsEmptyValue(const pir::Value& value) { return !value.impl() || !value.type(); } diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.h b/paddle/fluid/pir/dialect/operator/utils/utils.h index 1ebe7d244affdd..0e14077bb8559d 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.h +++ b/paddle/fluid/pir/dialect/operator/utils/utils.h @@ -132,6 +132,10 @@ VariantType GetAttributeData(const pir::Attribute& attr); bool IsLegacyOp(const std::string& name); +#ifdef PADDLE_WITH_DNNL +bool IsOneDNNLegacyOp(const std::string& name); +#endif + bool IsEmptyValue(const pir::Value& value); std::vector GetInt64Vector(const pir::Attribute& attr); diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index 4731b61541e21b..7d1896de1d7cf8 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -17,6 +17,7 @@ #include #include "paddle/fluid/framework/op_kernel_type.h" +#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_op.h" @@ -44,6 +45,12 @@ #include "paddle/pir/dialect/control_flow/ir/cf_op.h" #include "paddle/utils/flags.h" +#ifdef PADDLE_WITH_DNNL +#include "paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.h" +#include "paddle/fluid/pir/dialect/operator/trait/onednn.h" +#endif + PHI_DECLARE_bool(print_ir); namespace paddle { namespace dialect { @@ -337,6 +344,49 @@ static pir::OpResult AddPlaceTransferOp(pir::Value in, return new_in; } +#ifdef PADDLE_WITH_DNNL +static pir::OpResult AddOneDNN2PaddleLayoutTransferOp( + pir::Value in, const phi::DataLayout& dst_layout, pir::Block* block) { + pir::IrContext* ctx = pir::IrContext::Instance(); + auto in_alloc_type = in.type().dyn_cast(); + + phi::KernelKey kernel_key; + kernel_key.set_backend(phi::Backend::CPU); + kernel_key.set_layout(phi::DataLayout::ANY); + kernel_key.set_dtype(dialect::TransToPhiDataType(in_alloc_type.dtype())); + + std::unordered_map op_attribute; + op_attribute = { + {"op_name", pir::StrAttribute::get(ctx, "pd_op.onednn_to_paddle_layout")}, + {"kernel_name", pir::StrAttribute::get(ctx, "onednn_to_paddle_layout")}, + {"kernel_key", KernelAttribute::get(ctx, kernel_key)}, + {"dst_layout", + pir::Int32Attribute::get(ctx, static_cast(dst_layout))}}; + + auto out_type = AllocatedDenseTensorType::get(ctx, + in_alloc_type.place(), + in_alloc_type.dtype(), + in_alloc_type.dims(), + dst_layout, + in_alloc_type.lod(), + in_alloc_type.offset()); + + pir::OpInfo kernel_op_info = ctx->GetRegisteredOpInfo(PhiKernelOp::name()); + pir::Operation* op = + pir::Operation::Create({in}, op_attribute, {out_type}, kernel_op_info); + + auto in_op = in.dyn_cast().owner(); + if (in_op && in_op->HasAttribute(kAttrIsPersisable)) { + op->set_attribute(kAttrIsPersisable, in_op->attribute(kAttrIsPersisable)); + } + + block->push_back(op); + auto new_in = op->result(0); + + return new_in; +} +#endif + static bool NeedTransformDataType(const phi::DataType& l, const phi::DataType& r) { return l != phi::DataType::ALL_DTYPE && r != phi::DataType::ALL_DTYPE && @@ -424,6 +474,46 @@ static pir::Type BuildOutputType(pir::Type type, } } +#ifdef PADDLE_WITH_DNNL +template +static pir::Type create_type(pir::Type type, + const phi::Place& place, + const phi::DataLayout& layout, + pir::Type out_dtype, + pir::IrContext* ctx) { + auto input_type = type.dyn_cast(); + return IrType2::get(ctx, + place, + out_dtype, + input_type.dims(), + layout, + input_type.lod(), + input_type.offset()); +} + +static pir::Type BuildOutputType(pir::Type type, + const phi::Place& place, + const phi::DataLayout& layout, + pir::IrContext* ctx) { + if (type.isa()) { + auto out_dtype = type.dyn_cast().dtype(); + return create_type( + type, place, layout, out_dtype, ctx); + } else if (type.isa()) { + auto out_dtype = type.dyn_cast().dtype(); + return create_type( + type, place, layout, out_dtype, ctx); + } else if (type.isa()) { + auto array_type = type.dyn_cast(); + return AllocatedDenseTensorArrayType::get( + ctx, place, array_type.dtype(), layout); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "BuildOutputType only support DenseTensorType and SelectedRowsType")); + } +} +#endif + pir::OpResult AddDtypeTransferOp(pir::Value in, pir::Block* block, const phi::KernelKey& kernel_key, @@ -666,6 +756,49 @@ std::string GetKernelName(const OpYamlInfoParser* op_info_parser, return kernel_fn_str; } +#ifdef PADDLE_WITH_DNNL +bool SupportsMKLDNN(const std::string& kernel_name, + const phi::DataType data_type) { + auto phi_kernels = + phi::KernelFactory::Instance().SelectKernelMap(kernel_name); + auto has_phi_kernel = + std::any_of(phi_kernels.begin(), + phi_kernels.end(), + [data_type](phi::KernelKeyMap::const_reference kern_pair) { + return kern_pair.first.backend() == phi::Backend::ONEDNN && + kern_pair.first.dtype() == data_type; + }); + if (has_phi_kernel) { + return true; + } else { + auto op_kernel_iter = + paddle::framework::OperatorWithKernel::AllOpKernels().find( + phi::TransToFluidOpName(kernel_name)); + if (op_kernel_iter == + paddle::framework::OperatorWithKernel::AllOpKernels().end()) { + return false; + } else { + auto& op_kernels = op_kernel_iter->second; + return std::any_of( + op_kernels.begin(), + op_kernels.end(), + [data_type](std::unordered_map< + paddle::framework::OpKernelType, + std::function, + paddle::framework::OpKernelType::Hash>::const_reference + kern_pair) { + return platform::is_cpu_place(kern_pair.first.place_) && + kern_pair.first.library_type_ == + paddle::framework::LibraryType::kMKLDNN && + kern_pair.first.data_type_ == + paddle::framework::TransToProtoVarType(data_type); + }); + } + } +} +#endif + phi::KernelKey GetKernelKey( pir::Operation* op, const phi::Place& place, @@ -894,6 +1027,13 @@ phi::KernelKey GetKernelKey( "to GPU"; } +#ifdef PADDLE_WITH_DNNL + if (op->HasTrait() && res.backend() == phi::Backend::CPU && + SupportsMKLDNN(kernel_fn_str, res.dtype())) { + res.set_backend(phi::Backend::ONEDNN); + res.set_layout(phi::DataLayout::ONEDNN); + } +#endif return res; } @@ -1370,7 +1510,17 @@ std::vector BuildOutputs(pir::Operation* op_item, } else if (result_type.isa() || result_type.isa() || result_type.isa()) { +#ifdef PADDLE_WITH_DNNL + if (kernel_key.backend() == phi::Backend::ONEDNN) { + op_output_types.push_back(BuildOutputType( + result_type, out_place, phi::DataLayout::ONEDNN, ctx)); + } else { + op_output_types.push_back(BuildOutputType(result_type, out_place, ctx)); + } +#else op_output_types.push_back(BuildOutputType(result_type, out_place, ctx)); +#endif + } else if (result_type.isa()) { std::vector vec_inner_types; auto base_types = result_type.dyn_cast().data(); @@ -1378,8 +1528,18 @@ std::vector BuildOutputs(pir::Operation* op_item, if (base_type) { if (base_type.isa() || base_type.isa()) { +#ifdef PADDLE_WITH_DNNL + if (kernel_key.backend() == phi::Backend::ONEDNN) { + vec_inner_types.push_back(BuildOutputType( + base_type, out_place, phi::DataLayout::ONEDNN, ctx)); + } else { + vec_inner_types.push_back( + BuildOutputType(base_type, out_place, ctx)); + } +#else vec_inner_types.push_back( BuildOutputType(base_type, out_place, ctx)); +#endif } else { PADDLE_THROW(phi::errors::Unimplemented( "only support dense tensor and selected rows in vector type " @@ -1390,6 +1550,11 @@ std::vector BuildOutputs(pir::Operation* op_item, pir::Type fp32_dtype = pir::Float32Type::get(ctx); phi::DDim dims = {}; phi::DataLayout data_layout = phi::DataLayout::NCHW; +#ifdef PADDLE_WITH_DNNL + if (kernel_key.backend() == phi::Backend::ONEDNN) { + data_layout = phi::DataLayout::ONEDNN; + } +#endif phi::LoD lod = {{}}; size_t offset = 0; auto dense_tensor_dtype = DenseTensorType::get( @@ -1458,7 +1623,21 @@ std::vector BuildInputs( } } - // 1.backend transfer + // 1. layout transfer(only for onednn) +#ifdef PADDLE_WITH_DNNL + if (kernel_key.backend() != phi::Backend::ONEDNN) { + auto new_in_type = new_in.type(); + if (new_in_type.isa()) { + if (new_in_type.dyn_cast().data_layout() == + phi::DataLayout::ONEDNN) { + new_in = AddOneDNN2PaddleLayoutTransferOp( + new_in, phi::DataLayout::ANY, block); + } + } + } +#endif + + // 2.backend transfer bool check_place_transfer = (op_item->isa<::pir::SetParameterOp>()) || (kernel.IsValid() && (!UnchangeOutputOps.count(op_item->name()))); @@ -1659,7 +1838,7 @@ std::vector BuildInputs( } } - // 2. dtype transfer + // 3. dtype transfer if (op_info_parser != nullptr) { std::string var_name = op_info_parser->InputNames()[i]; auto fake_tensors = PrepareFakeTensors(new_in); @@ -1689,6 +1868,7 @@ std::vector BuildInputs( } } } + vec_inputs.push_back(new_in); } return vec_inputs; @@ -1768,18 +1948,76 @@ pir::Operation* BuildKernelOp( op_attribute.emplace("is_inplace", pir::BoolAttribute::get(ctx, true)); } - pir::OpInfo phi_kernel_op_info = - ctx->GetRegisteredOpInfo(PhiKernelOp::name()); - - pir::OpInfo legacy_kernel_op_info = - ctx->GetRegisteredOpInfo(LegacyKernelOp::name()); pir::Operation* op = nullptr; - if (IsLegacyOp(op_item->name())) { - op = pir::Operation::Create( - vec_inputs, op_attribute, op_output_types, legacy_kernel_op_info); - } else { - op = pir::Operation::Create( - vec_inputs, op_attribute, op_output_types, phi_kernel_op_info); +#ifdef PADDLE_WITH_DNNL + if (op_item->HasTrait()) { + if (IsOneDNNLegacyOp(op_item->name())) { + VLOG(4) << "choose OneDNNLegacyKernelOp"; + pir::OpInfo legacy_kernel_op_info = + ctx->GetRegisteredOpInfo(OneDNNLegacyKernelOp::name()); + op = pir::Operation::Create( + vec_inputs, op_attribute, op_output_types, legacy_kernel_op_info); + } else { + auto op_info_parser = GetOpYamlInfoParser(op_item); + std::vector extra_args; + for (auto& arg : op_info_parser->OpRuntimeInfo().extra_args) { + extra_args.push_back(pir::StrAttribute::get(ctx, arg)); + } + op_attribute.emplace( + "extra_args", + pir::ArrayAttribute::get(pir::IrContext::Instance(), extra_args)); + op_attribute.emplace( + "layout_transform_arg", + pir::StrAttribute::get( + ctx, op_info_parser->OpRuntimeInfo().layout_transform_arg)); + std::vector layout_transform_inputs; + for (auto& input : + op_info_parser->OpRuntimeInfo().layout_transform_inputs) { + layout_transform_inputs.push_back(pir::StrAttribute::get(ctx, input)); + } + op_attribute.emplace("layout_transform_inputs", + pir::ArrayAttribute::get(pir::IrContext::Instance(), + layout_transform_inputs)); + op_attribute.emplace( + "is_onednn_only", + pir::BoolAttribute::get( + ctx, op_info_parser->OpRuntimeInfo().is_onednn_only)); + op_attribute.emplace( + "dynamic_fallback", + pir::BoolAttribute::get( + ctx, op_info_parser->OpRuntimeInfo().dynamic_fallback)); + if (op_item->HasTrait()) { + VLOG(4) << "choose OneDNNMixedPhiKernelOp"; + pir::OpInfo phi_kernel_op_info = + ctx->GetRegisteredOpInfo(OneDNNMixedPhiKernelOp::name()); + + op = pir::Operation::Create( + vec_inputs, op_attribute, op_output_types, phi_kernel_op_info); + } else { + VLOG(4) << "choose OneDNNPhiKernelOp"; + pir::OpInfo phi_kernel_op_info = + ctx->GetRegisteredOpInfo(OneDNNPhiKernelOp::name()); + + op = pir::Operation::Create( + vec_inputs, op_attribute, op_output_types, phi_kernel_op_info); + } + } + } else // NOLINT +#endif + { + if (IsLegacyOp(op_item->name())) { + pir::OpInfo legacy_kernel_op_info = + ctx->GetRegisteredOpInfo(LegacyKernelOp::name()); + + op = pir::Operation::Create( + vec_inputs, op_attribute, op_output_types, legacy_kernel_op_info); + } else { + pir::OpInfo phi_kernel_op_info = + ctx->GetRegisteredOpInfo(PhiKernelOp::name()); + + op = pir::Operation::Create( + vec_inputs, op_attribute, op_output_types, phi_kernel_op_info); + } } (*map_op_pair)[op_item] = op; @@ -1804,10 +2042,11 @@ void ProcessBlock( std::unordered_map* map_value_pair) { auto inputs_by_data_op = GetInputsByDataOp(block); - for (auto& op_item : *block) { - VLOG(6) << "op name " << op_item.name(); - if ((op_item.isa()) && - inputs_by_data_op.count(op_item.attributes() + for (auto iter = block->begin(); iter != block->end(); ++iter) { + pir::Operation* op_item = &(*iter); + VLOG(6) << "op name " << op_item->name(); + if ((op_item->isa()) && + inputs_by_data_op.count(op_item->attributes() .at("name") .dyn_cast() .AsString())) { @@ -1816,24 +2055,55 @@ void ProcessBlock( } // HandleSpecialOp - if (SpecialLowerOps.count(op_item.name())) { - VLOG(6) << "Handle Special Op: [" << op_item.name() + if (SpecialLowerOps.count(op_item->name())) { + VLOG(6) << "Handle Special Op: [" << op_item->name() << "] while lowering to kernel pass"; HandleForSpecialOp( - place, &op_item, new_block, ctx, map_op_pair, map_value_pair); + place, op_item, new_block, ctx, map_op_pair, map_value_pair); continue; } - auto op_info_parser = GetOpYamlInfoParser(&op_item); - auto kernel_name = GetKernelName(op_info_parser.get(), &op_item); + auto op_info_parser = GetOpYamlInfoParser(op_item); + auto kernel_name = GetKernelName(op_info_parser.get(), op_item); auto kernel_key = GetKernelKey( - &op_item, place, kernel_name, *map_value_pair, op_info_parser.get()); + op_item, place, kernel_name, *map_value_pair, op_info_parser.get()); VLOG(6) << "kernel type " << kernel_key; +#ifdef PADDLE_WITH_DNNL + if (op_item->HasTrait() && + kernel_key.backend() != phi::Backend::ONEDNN) { + std::vector op_item_inner_output_types; + if (op_item->num_results() > 0) { + for (size_t i = 0; i < op_item->num_results(); ++i) { + op_item_inner_output_types.push_back(op_item->result_type(i)); + } + } + std::string target_op_name = op_item->name(); + target_op_name.replace(0, 12, "pd_op"); + auto op_info = ctx->GetRegisteredOpInfo(target_op_name); + if (!op_info) { + IR_THROW("Ctx should have corresponding OpInfo %s", target_op_name); + } + pir::Operation* op_item_inner = + pir::Operation::Create(op_item->operands_source(), + op_item->attributes(), + op_item_inner_output_types, + op_info); + op_item->ReplaceAllUsesWith(op_item_inner->results()); + for (auto iter = block->begin(); iter != block->end(); ++iter) { + if (*iter == *op_item) { + block->Assign(iter, op_item_inner); + break; + } + } + op_item = op_item_inner; + op_info_parser = GetOpYamlInfoParser(op_item_inner); + } +#endif // build output type - auto op_output_types = BuildOutputs(&op_item, kernel_name, kernel_key, ctx); + auto op_output_types = BuildOutputs(op_item, kernel_name, kernel_key, ctx); // build input - auto vec_inputs = BuildInputs(&op_item, + auto vec_inputs = BuildInputs(op_item, kernel_name, kernel_key, place, @@ -1848,14 +2118,14 @@ void ProcessBlock( kernel_key, vec_inputs, op_output_types, - &op_item, + op_item, new_block, ctx, map_op_pair, map_value_pair); AddShadowFeedOpForDataOrFeed( - place, &op_item, op, new_block, ctx, map_op_pair, map_value_pair); + place, op_item, op, new_block, ctx, map_op_pair, map_value_pair); } } @@ -1872,7 +2142,10 @@ std::unique_ptr PdOpLowerToKernelPass(pir::Program* prog, pir::IrContext* ctx = pir::IrContext::Instance(); ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); - +#ifdef PADDLE_WITH_DNNL + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); +#endif std::unordered_map map_op_pair; std::unordered_map map_value_pair; diff --git a/paddle/phi/api/lib/data_transform.h b/paddle/phi/api/lib/data_transform.h index dd3166f05c3ef9..e0509fa8582ae2 100644 --- a/paddle/phi/api/lib/data_transform.h +++ b/paddle/phi/api/lib/data_transform.h @@ -178,6 +178,11 @@ inline bool NeedTransformPlace(const phi::Place& src_place, (target != Backend::ALL_BACKEND && phi::TransToPhiBackend(src_place) != (target != Backend::GPUDNN ? target : Backend::GPU)); +#ifdef PADDLE_WITH_DNNL + if (target == Backend::ONEDNN) { + ret = src_place.GetType() != AllocationType::CPU; + } +#endif return ret; } diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index f0e87043d965dd..7cbbc837cecf45 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -2466,6 +2466,15 @@ outputs : {q : Q, r : R} +- op : quantize + backward : quantize_grad + inputs : + input : Input + outputs : + output : Output + attrs : + {scale : Scale, shift : Shift, include_self: Include_self} + - op : quantize_linear extra : attrs : [float moving_rate = 0.9] diff --git a/paddle/phi/kernels/cpu/onednn_to_paddle_layout_kernel.cc b/paddle/phi/kernels/cpu/onednn_to_paddle_layout_kernel.cc new file mode 100644 index 00000000000000..eba8b2b61f4d27 --- /dev/null +++ b/paddle/phi/kernels/cpu/onednn_to_paddle_layout_kernel.cc @@ -0,0 +1,94 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/onednn_to_paddle_layout_kernel.h" + +#include +#include + +#include "glog/logging.h" + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/visit_type.h" +#include "paddle/phi/kernels/funcs/data_layout_transform.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/memcpy_kernel.h" +#ifdef PADDLE_WITH_DNNL +#include "paddle/phi/backends/onednn/onednn_helper.h" +#endif +namespace phi { + +template +void OneDNN2PaddleLayout(const Context& dev_ctx, + const DenseTensor& x, + int dst_layout, + DenseTensor* out) { +#ifdef PADDLE_WITH_DNNL + DataLayout src_layout = x.layout(); + VLOG(10) << "TransDataLayout from " << static_cast(src_layout) + << " -> " << static_cast(dst_layout); + + auto print_tensor_meta = [](const DenseTensor& x) { + std::ostringstream oss; + + oss << "["; + oss << "layout:" << x.layout() << " ,"; + oss << "dims:" << x.dims() << " ,"; + if (x.IsInitialized()) oss << "place:" << x.place(); + oss << "]"; + + return oss.str(); + }; + VLOG(10) << " x: " << print_tensor_meta(x); + VLOG(10) << " out: " << print_tensor_meta(*out) << " " << out; + + if (src_layout != DataLayout::ONEDNN) { + out->ShareDataWith(x); + out->ShareInplaceVersionCounterWith(x); + out->set_layout(static_cast(dst_layout)); + return; + } + + DataLayout tmp_layout = static_cast(dst_layout); + if (static_cast(dst_layout) == DataLayout::ANY) { + tmp_layout = phi::OneDNNContext::tls().get_cur_paddle_data_layout(); + } + + if (tmp_layout == DataLayout::ANY) { + tmp_layout = phi::OneDNNContext::tls().get_cur_paddle_data_layout(); + } + + // NOTE(zhiqiu): to handle the special case in ApplyDataTransform() in + // data_transfer.cc + if (!x.IsInitialized() && src_layout == DataLayout::ONEDNN && + tmp_layout == DataLayout::NHWC) { + VLOG(4) << src_layout << "->" << tmp_layout << " " << x.layout(); + out->Resize(x.dims()); + out->set_layout(tmp_layout); + funcs::MatchShapeToLayout(out, src_layout, tmp_layout); + return; + } + + funcs::TransDataLayoutFromOneDNN( + src_layout, tmp_layout, x, out, dev_ctx.GetPlace()); +#endif +} + +} // namespace phi + +PD_REGISTER_KERNEL_FOR_ALL_DTYPE(onednn_to_paddle_layout, + CPU, + ALL_LAYOUT, + phi::OneDNN2PaddleLayout) {} diff --git a/paddle/phi/kernels/onednn_to_paddle_layout_kernel.h b/paddle/phi/kernels/onednn_to_paddle_layout_kernel.h new file mode 100644 index 00000000000000..a6ddc280c4e3c8 --- /dev/null +++ b/paddle/phi/kernels/onednn_to_paddle_layout_kernel.h @@ -0,0 +1,28 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/infermeta/unary.h" +#include "paddle/phi/kernels/empty_kernel.h" + +namespace phi { + +template +void OneDNN2PaddleLayout(const Context& dev_ctx, + const DenseTensor& x, + int dst_layout, + DenseTensor* out); +} // namespace phi diff --git a/test/mkldnn/test_conv2d_mkldnn_op.py b/test/mkldnn/test_conv2d_mkldnn_op.py index 3c77581acf80db..2d6cafdbc3734b 100644 --- a/test/mkldnn/test_conv2d_mkldnn_op.py +++ b/test/mkldnn/test_conv2d_mkldnn_op.py @@ -17,6 +17,9 @@ import numpy as np from op_test import OpTest, skip_check_grad_ci from test_conv2d_op import TestConv2DOp, TestConv2DOp_v2 +from utils import compare_legacy_with_pt + +from paddle.base import core def conv2d_bias_naive(out, bias): @@ -113,6 +116,94 @@ def setUp(self): self.outputs['Output'] = output +class TestConv2DMKLDNNOp2(TestConv2DOp): + def init_group(self): + self.groups = 1 + + def init_kernel_type(self): + self.data_format = "NCHW" + self.use_mkldnn = True + self._cpu_only = True + self.dtype = np.float32 + + def init_test_case(self): + self.pad = [0, 0] + self.stride = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3] + + def setUp(self): + self.fuse_bias = False + self.bias_size = None + self.fuse_activation = "" + self.fuse_alpha = 0 + self.fuse_beta = 0 + self.fuse_residual_connection = False + self.input_residual_size = None + + TestConv2DOp.setUp(self) + + output = self.outputs['Output'] + + # mkldnn only support either conv-sum-relu, or conv-relu. + if self.fuse_bias and self.bias_size is not None: + bias = np.random.random(self.bias_size).astype(self.dtype) + output = conv2d_bias_naive(output, bias) + output = output.astype(self.dtype) + self.attrs['fuse_bias'] = self.fuse_bias + self.inputs['Bias'] = OpTest.np_dtype_to_base_dtype(bias) + + if ( + self.fuse_residual_connection + and self.input_residual_size is not None + ): + input_residual = np.random.random(self.input_residual_size).astype( + self.dtype + ) + output = conv2d_residual_naive(output, input_residual) + + self.attrs[ + 'fuse_residual_connection' + ] = self.fuse_residual_connection + self.inputs['ResidualData'] = OpTest.np_dtype_to_base_dtype( + input_residual + ) + + if self.fuse_activation == "relu": + output = np.maximum(output, 0).astype(self.dsttype) + + if self.fuse_activation == "relu6": + output = np.minimum(np.maximum(output, 0), self.fuse_beta).astype( + self.dsttype + ) + if ( + self.fuse_activation != "" + or self.fuse_bias + or self.fuse_residual_connection + ): + self.op_type = 'fused_conv2d' + + output = output.astype(self.dtype) + + self.attrs['fuse_bias'] = self.fuse_bias + self.attrs['fuse_activation'] = self.fuse_activation + self.attrs['fuse_alpha'] = self.fuse_alpha + self.attrs['fuse_beta'] = self.fuse_beta + self.attrs['fuse_residual_connection'] = self.fuse_residual_connection + + self.outputs['Output'] = output + + @compare_legacy_with_pt + def test_check_output(self): + place = core.CUDAPlace(0) if self.has_cuda() else core.CPUPlace() + # TODO(wangzhongpu): support mkldnn op in dygraph mode + self.check_output_with_place( + place, atol=1e-5, check_dygraph=(not self.use_mkldnn) + ) + + @skip_check_grad_ci( reason="Fusion is for inference only, check_grad is not required." )