diff --git a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc deleted file mode 100644 index 0af0b49076795..0000000000000 --- a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#include "paddle/fluid/operators/activation_op.h" -#include "paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h" -#include "paddle/fluid/platform/mkldnn_reuse.h" - -namespace phi { -class DenseTensor; -} // namespace phi - -namespace paddle { -namespace operators { - -using dnnl::memory; -using dnnl::primitive; -using dnnl::stream; -using phi::DataLayout; - -using platform::GetMKLDNNFormat; -using platform::MKLDNNDeviceContext; -using platform::to_void_cast; - -template -class MKLDNNActivationKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - Functor functor; - functor(ctx); - } -}; - -template -struct SoftplusMKLDNNFunctor : public BaseActivationFunctor { - void operator()(const framework::ExecutionContext &ctx) const { - custom_softplus_eltwise_forward(ctx); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -#define REGISTER_FWD_ACTIVATION_MKLDNN_KERNEL(act_type, functor) \ - REGISTER_OP_KERNEL( \ - act_type, \ - MKLDNN, \ - ::paddle::platform::CPUPlace, \ - ops::MKLDNNActivationKernel>, \ - ops::MKLDNNActivationKernel>); - -REGISTER_FWD_ACTIVATION_MKLDNN_KERNEL(softplus, SoftplusMKLDNNFunctor); diff --git a/paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h b/paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h deleted file mode 100644 index 25886c5791fea..0000000000000 --- a/paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h +++ /dev/null @@ -1,105 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/fluid/platform/mkldnn_reuse.h" - -namespace paddle { -namespace operators { - -template -class SoftplusMKLDNNHandler - : public platform::MKLDNNHandlerNoCachingT { - public: - SoftplusMKLDNNHandler(const framework::ExecutionContext& ctx, - const phi::DenseTensor* x, - const float beta, - const dnnl::engine engine) - : platform::MKLDNNHandlerNoCachingT(engine, - ctx.GetPlace()) { - auto x_tz = phi::vectorize(x->dims()); - - auto beta_tz = std::vector(x_tz.size(), 1); - auto beta_md = - dnnl::memory::desc(beta_tz, - platform::MKLDNNGetDataType(), - platform::GetPlainMKLDNNFormat(x_tz.size())); - - dnnl::post_ops post_ops; - post_ops.append_eltwise( - 1.0f, dnnl::algorithm::eltwise_soft_relu, 0.0f, 0.0f); - if (beta != 1.0f) { - post_ops.append_eltwise( - 1.0f, dnnl::algorithm::eltwise_linear, 1.0f / beta, 0.0f); - } - - platform::AppendActivation(ctx, post_ops); - - dnnl::primitive_attr attrs; - attrs.set_post_ops(post_ops); - - this->AcquireForwardPrimitiveDescriptor(attrs, - dnnl::algorithm::binary_mul, - x->mem_desc(), - beta_md, - x->mem_desc()); - } - - std::shared_ptr AcquireBetaMemory(const float* beta) { - return this->AcquireMemoryFromPrimitive( - this->fwd_pd_->src1_desc(), platform::to_void_cast(beta)); - } -}; - -template -void custom_softplus_eltwise_forward(const framework::ExecutionContext& ctx) { - const auto& dev_ctx = - ctx.template device_context(); - const auto& mkldnn_engine = dev_ctx.GetEngine(); - - const auto* x = ctx.Input("X"); - auto* out = ctx.Output("Out"); - - bool is_inplaced = x->IsSharedBufferWith(*out); - - const float beta = ctx.Attr("beta"); - - SoftplusMKLDNNHandler handler(ctx, x, beta, mkldnn_engine); - - auto src_memory_p = handler.AcquireSrcMemory(x); - - auto beta_memory_p = handler.AcquireBetaMemory(&beta); - std::shared_ptr dst_memory_p = nullptr; - if (is_inplaced) { - dst_memory_p = src_memory_p; - out->mutable_data(ctx.GetPlace()); - } else { - dst_memory_p = handler.AcquireDstMemory(out); - } - auto binary_p = handler.AcquireForwardPrimitive(); - - auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); - - const std::unordered_map args = { - {DNNL_ARG_SRC_0, *src_memory_p}, - {DNNL_ARG_SRC_1, *beta_memory_p}, - {DNNL_ARG_DST, *dst_memory_p}}; - - binary_p->execute(astream, args); - astream.wait(); - - out->set_mem_desc(dst_memory_p->get_desc()); -} -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/kernels/onednn/softplus_kernel.cc b/paddle/phi/kernels/onednn/softplus_kernel.cc new file mode 100644 index 0000000000000..b87938e3dc11b --- /dev/null +++ b/paddle/phi/kernels/onednn/softplus_kernel.cc @@ -0,0 +1,100 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/activation_kernel.h" + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +class SoftplusOneDNNHandler + : public funcs::OneDNNHandlerNoCachingT { + public: + SoftplusOneDNNHandler(const OneDNNContext& dev_ctx, + const phi::DenseTensor* x, + const float beta) + : funcs::OneDNNHandlerNoCachingT(dev_ctx.GetEngine(), + dev_ctx.GetPlace()) { + dnnl::post_ops post_ops; + post_ops.append_eltwise( + 1.0f, dnnl::algorithm::eltwise_soft_relu, 0.0f, 0.0f); + if (beta != 1.0f) { + post_ops.append_eltwise( + 1.0f, dnnl::algorithm::eltwise_linear, 1.0f / beta, 0.0f); + } + funcs::AppendActivation(dev_ctx, post_ops); + dnnl::primitive_attr attrs; + attrs.set_post_ops(post_ops); + + auto x_tz = phi::vectorize(x->dims()); + auto beta_tz = std::vector(x_tz.size(), 1); + auto beta_md = dnnl::memory::desc(beta_tz, + funcs::OneDNNGetDataType(), + funcs::GetPlainOneDNNFormat(x_tz.size())); + + this->AcquireForwardPrimitiveDescriptor(attrs, + dnnl::algorithm::binary_mul, + x->mem_desc(), + beta_md, + x->mem_desc()); + } + + std::shared_ptr AcquireBetaMemory(const float* beta) { + return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src1_desc(), + funcs::to_void_cast(beta)); + } +}; + +template +void SoftplusKernel(const Context& dev_ctx, + const DenseTensor& x, + float beta, + float threshold, + DenseTensor* out) { + SoftplusOneDNNHandler handler(dev_ctx, &x, beta); + + auto src_memory_p = handler.AcquireSrcMemory(&x); + auto beta_memory_p = handler.AcquireBetaMemory(&beta); + std::shared_ptr dst_memory_p = nullptr; + if (x.IsSharedBufferWith(*out)) { + dst_memory_p = src_memory_p; + dev_ctx.template Alloc(out); + } else { + dst_memory_p = handler.AcquireDstMemory(out); + } + auto binary_p = handler.AcquireForwardPrimitive(); + + auto& astream = OneDNNContext::tls().get_stream(); + + const std::unordered_map args = { + {DNNL_ARG_SRC_0, *src_memory_p}, + {DNNL_ARG_SRC_1, *beta_memory_p}, + {DNNL_ARG_DST, *dst_memory_p}}; + + binary_p->execute(astream, args); + astream.wait(); + + out->set_mem_desc(dst_memory_p->get_desc()); +} + +} // namespace phi + +PD_REGISTER_KERNEL(softplus, + OneDNN, + ONEDNN, + phi::SoftplusKernel, + float, + phi::dtype::bfloat16) {}