Skip to content

Commit

Permalink
[PHI] Migrate softplus kernel (PaddlePaddle#47406)
Browse files Browse the repository at this point in the history
* add extra attr property set

* add type_info for all context

* add onednn context to all context

* fix context compile error

* simplify conv kernel args

* pass runtime attr into dev_ctx

* fix marco error

* clear conv_grad_kernel extra args

* merge conv_grad_grad into conv_grad

* clear conv2d_grad_grad extra attrs

* remove redundant imports

* migrate softmax

* clear yaml and eager extra attr

* fix conv1d error

* change to thread local

* fix npu compile failed

* try to fix windows compile failed

* add conv2d onednn phi kernel

* fix ci bugs (#36)

* fix compile bugs (#38)

* fix extra input transform bug (#39)

* support dynamic created attr (#40)

* reset extra info gen code

* rm conv_grad_grad kernel

* reimpl pass attr adapting

* add int attr support

* remove vector inputnames creating

* merge dev

* fix map at error

* adjust attribute

* adapt funcs to PHI

* init

* adjust imports

* support postops

* format codeblocks

* revert changes to softmax

Co-authored-by: Chen Weihang <chenweihang@baidu.com>
Co-authored-by: YuanRisheng <yuanrisheng@baidu.com>
  • Loading branch information
3 people authored Nov 4, 2022
1 parent c2483af commit 1831919
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 170 deletions.
65 changes: 0 additions & 65 deletions paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc

This file was deleted.

105 changes: 0 additions & 105 deletions paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h

This file was deleted.

100 changes: 100 additions & 0 deletions paddle/phi/kernels/onednn/softplus_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/activation_kernel.h"

#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T>
class SoftplusOneDNNHandler
: public funcs::OneDNNHandlerNoCachingT<T, dnnl::binary> {
public:
SoftplusOneDNNHandler(const OneDNNContext& dev_ctx,
const phi::DenseTensor* x,
const float beta)
: funcs::OneDNNHandlerNoCachingT<T, dnnl::binary>(dev_ctx.GetEngine(),
dev_ctx.GetPlace()) {
dnnl::post_ops post_ops;
post_ops.append_eltwise(
1.0f, dnnl::algorithm::eltwise_soft_relu, 0.0f, 0.0f);
if (beta != 1.0f) {
post_ops.append_eltwise(
1.0f, dnnl::algorithm::eltwise_linear, 1.0f / beta, 0.0f);
}
funcs::AppendActivation(dev_ctx, post_ops);
dnnl::primitive_attr attrs;
attrs.set_post_ops(post_ops);

auto x_tz = phi::vectorize(x->dims());
auto beta_tz = std::vector<int64_t>(x_tz.size(), 1);
auto beta_md = dnnl::memory::desc(beta_tz,
funcs::OneDNNGetDataType<T>(),
funcs::GetPlainOneDNNFormat(x_tz.size()));

this->AcquireForwardPrimitiveDescriptor(attrs,
dnnl::algorithm::binary_mul,
x->mem_desc(),
beta_md,
x->mem_desc());
}

std::shared_ptr<dnnl::memory> AcquireBetaMemory(const float* beta) {
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src1_desc(),
funcs::to_void_cast<float>(beta));
}
};

template <typename T, typename Context>
void SoftplusKernel(const Context& dev_ctx,
const DenseTensor& x,
float beta,
float threshold,
DenseTensor* out) {
SoftplusOneDNNHandler<T> handler(dev_ctx, &x, beta);

auto src_memory_p = handler.AcquireSrcMemory(&x);
auto beta_memory_p = handler.AcquireBetaMemory(&beta);
std::shared_ptr<dnnl::memory> dst_memory_p = nullptr;
if (x.IsSharedBufferWith(*out)) {
dst_memory_p = src_memory_p;
dev_ctx.template Alloc<T>(out);
} else {
dst_memory_p = handler.AcquireDstMemory(out);
}
auto binary_p = handler.AcquireForwardPrimitive();

auto& astream = OneDNNContext::tls().get_stream();

const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *src_memory_p},
{DNNL_ARG_SRC_1, *beta_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};

binary_p->execute(astream, args);
astream.wait();

out->set_mem_desc(dst_memory_p->get_desc());
}

} // namespace phi

PD_REGISTER_KERNEL(softplus,
OneDNN,
ONEDNN,
phi::SoftplusKernel,
float,
phi::dtype::bfloat16) {}

0 comments on commit 1831919

Please sign in to comment.