[PHI] Migrate squeeze and squeeze_grad kernels (#48634)
* squeeze kernel

* squeeze fwd

* whitespace
Silv3S authored Dec 7, 2022
1 parent 4aad4dc commit ad41fce
Showing 3 changed files with 147 additions and 36 deletions.
39 changes: 3 additions & 36 deletions paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc
@@ -21,7 +21,6 @@ enum class ReshapeKernelOpName {
   reshape,
   reshape2,
   squeeze,
-  squeeze2,
   flatten,
   flatten2,
 };
@@ -106,9 +105,6 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
       case ReshapeKernelOpName::squeeze:
         InferShapeSqueezeOp(ctx, x_dims, out_dims);
         break;
-      case ReshapeKernelOpName::squeeze2:
-        InferShapeSqueeze2Op(ctx, x_dims, out_dims);
-        break;
       case ReshapeKernelOpName::flatten:
         InferShapeFlattenOp(ctx, x_dims, out_dims);
         break;
@@ -172,16 +168,6 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
     out_dims = GetOutputShape(axes, x_dims, true);
   }
 
-  void InferShapeSqueeze2Op(const framework::ExecutionContext& ctx,
-                            framework::DDim& x_dims,            // NOLINT
-                            framework::DDim& out_dims) const {  // NOLINT
-    auto* out = ctx.Output<phi::DenseTensor>("Out");
-    auto* xshape = ctx.Output<phi::DenseTensor>("XShape");
-    auto xshape_dims = xshape->dims();
-    x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
-    out_dims = out->dims();
-  }
-
   void InferShapeFlattenOp(const framework::ExecutionContext& ctx,
                            framework::DDim& x_dims,            // NOLINT
                            framework::DDim& out_dims) const {  // NOLINT
@@ -342,19 +328,16 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T, op_name> {
         InferShapeReshapeSqueezeGradOp(ctx, x_dims);
         break;
       case ReshapeKernelOpName::reshape2:
-        InferShapeReshape2Squeeze2Flatten2GradOp(ctx, x_dims);
+        InferShapeReshape2Flatten2GradOp(ctx, x_dims);
         break;
       case ReshapeKernelOpName::squeeze:
         InferShapeReshapeSqueezeGradOp(ctx, x_dims);
         break;
-      case ReshapeKernelOpName::squeeze2:
-        InferShapeReshape2Squeeze2Flatten2GradOp(ctx, x_dims);
-        break;
       case ReshapeKernelOpName::flatten:
         InferShapeFlattenGradOp(ctx, x_dims);
         break;
       case ReshapeKernelOpName::flatten2:
-        InferShapeReshape2Squeeze2Flatten2GradOp(ctx, x_dims);
+        InferShapeReshape2Flatten2GradOp(ctx, x_dims);
         break;
       default:
         PADDLE_THROW(paddle::platform::errors::OutOfRange(
@@ -369,7 +352,7 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T, op_name> {
     dx_dims = dx->dims();
   }
 
-  void InferShapeReshape2Squeeze2Flatten2GradOp(
+  void InferShapeReshape2Flatten2GradOp(
       const framework::ExecutionContext& ctx,
       framework::DDim& dx_dims) const {  // NOLINT
     auto xshape_dims = ctx.Input<phi::DenseTensor>("XShape")->dims();
@@ -401,22 +384,6 @@ REGISTER_OP_KERNEL(
     ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16,
                                  ReshapeKernelOpName::squeeze>);
 
-REGISTER_OP_KERNEL(
-    squeeze2,
-    MKLDNN,
-    paddle::platform::CPUPlace,
-    ops::ReshapeMKLDNNKernel<float, ReshapeKernelOpName::squeeze2>,
-    ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16,
-                             ReshapeKernelOpName::squeeze2>);
-
-REGISTER_OP_KERNEL(
-    squeeze2_grad,
-    MKLDNN,
-    paddle::platform::CPUPlace,
-    ops::ReshapeGradMKLDNNKernel<float, ReshapeKernelOpName::squeeze2>,
-    ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16,
-                                 ReshapeKernelOpName::squeeze2>);
-
 REGISTER_OP_KERNEL(
     reshape,
     MKLDNN,
59 changes: 59 additions & 0 deletions paddle/phi/kernels/onednn/squeeze_grad_kernel.cc
@@ -0,0 +1,59 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/squeeze_grad_kernel.h"

#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void SqueezeGradKernel(const Context& dev_ctx,
                       const DenseTensor& xshape,
                       const DenseTensor& dout,
                       const IntArray& axes,
                       DenseTensor* dx) {
  auto dout_vec_dims = vectorize(dout.dims());
  auto dout_type = funcs::ToOneDNNDataType(dout.dtype());

  funcs::ReorderOneDNNHandler reorder_handler(
      dout_vec_dims, dout.dtype(), dout_type, dev_ctx.GetEngine());

  auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
      dout.mem_desc(), funcs::to_void_cast(dout.data<T>()));
  auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
      dx,
      funcs::GetPlainOneDNNFormat(dout_vec_dims.size()),
      dev_ctx.GetPlace());
  auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
                                                  reorder_src_memory_p);

  auto& astream = OneDNNContext::tls().get_stream();
  reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
  astream.wait();

  auto dx_dims = slice_ddim(xshape.dims(), 1, xshape.dims().size());
  dx->Resize(dx_dims);
  reorder_dst_memory_p->get_desc().reshape(vectorize(dx_dims));
}

} // namespace phi

PD_REGISTER_KERNEL(squeeze_grad,
                   OneDNN,
                   ONEDNN,
                   phi::SqueezeGradKernel,
                   float,
                   phi::dtype::bfloat16) {}
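Note: the grad kernel recovers dx's shape from the XShape tensor, whose leading dimension is a placeholder (conventionally 0) prepended by the *_with_xshape ops; `slice_ddim(xshape.dims(), 1, xshape.dims().size())` above drops it before resizing dx. A minimal standalone C++ sketch of that shape recovery, using plain `std::vector<int64_t>` instead of Paddle's DDim (the helper name `RecoverDxDims` is hypothetical):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring slice_ddim(xshape_dims, 1, rank):
// drop the leading placeholder dim that squeeze_with_xshape prepends.
std::vector<int64_t> RecoverDxDims(const std::vector<int64_t>& xshape_dims) {
  return {xshape_dims.begin() + 1, xshape_dims.end()};
}

int main() {
  // XShape stores {0, original dims...}; dx is resized to the tail.
  const std::vector<int64_t> xshape_dims = {0, 2, 1, 3};
  const std::vector<int64_t> dx_dims = RecoverDxDims(xshape_dims);
  assert((dx_dims == std::vector<int64_t>{2, 1, 3}));
  return 0;
}
```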
85 changes: 85 additions & 0 deletions paddle/phi/kernels/onednn/squeeze_kernel.cc
@@ -0,0 +1,85 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/squeeze_kernel.h"

#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h"

namespace phi {

template <typename T, typename Context>
void ExecuteSqueeze(const Context& dev_ctx,
                    const DenseTensor& x,
                    const DDim& x_dims,
                    const DDim& out_dims,
                    DenseTensor* out) {
  auto x_vec_dims = vectorize(x_dims);

  funcs::ReorderOneDNNHandler reorder_handler(
      x_vec_dims,
      x.dtype(),
      funcs::ToOneDNNDataType(x.dtype()),
      dev_ctx.GetEngine());

  auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
      x.mem_desc(), funcs::to_void_cast(x.data<T>()));
  out->Resize(x_dims);  // to match x numel, format is changed later
  // reorder is done into a plain tag to allow usage with blocked formats
  auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
      out, funcs::GetPlainOneDNNFormat(x_dims.size()), dev_ctx.GetPlace());
  auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
                                                  reorder_src_memory_p);
  auto& astream = OneDNNContext::tls().get_stream();
  reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
  astream.wait();

  out->Resize(out_dims);
  out->set_mem_desc(
      reorder_dst_memory_p->get_desc().reshape(vectorize(out_dims)));
}

template <typename T, typename Context>
void SqueezeKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   const IntArray& axes,
                   DenseTensor* out) {
  auto x_dims = x.dims();
  std::vector<int32_t> tmp(axes.GetData().begin(), axes.GetData().end());
  auto out_dims = funcs::GetOutputSqueezeShape(tmp, x_dims, true);
  ExecuteSqueeze<T, Context>(dev_ctx, x, x_dims, out_dims, out);
}

template <typename T, typename Context>
void SqueezeWithXShapeKernel(const Context& dev_ctx,
                             const DenseTensor& x,
                             const IntArray& axes,
                             DenseTensor* out,
                             DenseTensor* xshape) {
  auto x_dims = slice_ddim(xshape->dims(), 1, xshape->dims().size());
  auto out_dims = out->dims();
  ExecuteSqueeze<T, Context>(dev_ctx, x, x_dims, out_dims, out);
}
} // namespace phi

PD_REGISTER_KERNEL(
    squeeze, OneDNN, ONEDNN, phi::SqueezeKernel, float, phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(squeeze_with_xshape,
                   OneDNN,
                   ONEDNN,
                   phi::SqueezeWithXShapeKernel,
                   float,
                   phi::dtype::bfloat16) {}
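For reference, the shape rule the forward kernel delegates to `funcs::GetOutputSqueezeShape` drops size-1 dimensions, either all of them (empty axes list) or only the listed ones. A rough standalone sketch under that assumption (`SqueezeDims` is a hypothetical stand-in, ignoring negative-axis normalization):

```cpp
#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

// Hypothetical stand-in for the squeeze shape rule: if `axes` is empty,
// every size-1 dim is dropped; otherwise only the listed axes are dropped
// when their extent is 1 (negative axes assumed already normalized).
std::vector<int64_t> SqueezeDims(const std::vector<int64_t>& dims,
                                 const std::vector<int32_t>& axes) {
  const std::set<int32_t> listed(axes.begin(), axes.end());
  std::vector<int64_t> out;
  for (size_t i = 0; i < dims.size(); ++i) {
    const bool selected =
        axes.empty() || listed.count(static_cast<int32_t>(i)) > 0;
    if (selected && dims[i] == 1) continue;  // squeeze this dim away
    out.push_back(dims[i]);
  }
  return out;
}

int main() {
  // No axes: both unit dims of {2, 1, 3, 1} are removed.
  assert((SqueezeDims({2, 1, 3, 1}, {}) == std::vector<int64_t>{2, 3}));
  // Only axis 1: the trailing unit dim is kept.
  assert((SqueezeDims({2, 1, 3, 1}, {1}) == std::vector<int64_t>{2, 3, 1}));
  return 0;
}
```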
