From 22c4c189ee5d6e09c3b56f908411b2913ab57c80 Mon Sep 17 00:00:00 2001
From: jakpiase <62569058+jakpiase@users.noreply.github.com>
Date: Fri, 30 Jul 2021 10:44:49 +0200
Subject: [PATCH] Added reshape, reshape2, squeeze and squeeze2 BF16/FP32
 FWD/BWD kernels (#34219)

* test version of matmul_v2

* added matmul_v2 grad kernel

* minor changes

* minor changes

* minor change for CI approval

* CI fix

* CI fix

* added squeeze and squeeze2 kernels

* CI fix

* CI fix

* CI fix

* disabled tests when compiled with cuda

* added setting format_tag by strides

* added sigmoid BF16 FWD/BWD and gelu BF16 BWD

* changes after review

* Revert "added sigmoid BF16 FWD/BWD and gelu BF16 BWD"

This reverts commit 6e3f76720b545abfcff9f6052b46b73a1e745cae.

* Revert "Merge branch 'matmul_v2_grad' into squeeze2_op"

This reverts commit 06fcf67843a4a7884eccdf67a02a03575e1d4cb8, reversing
changes made to 6e3f76720b545abfcff9f6052b46b73a1e745cae.

* minor change

* added reshape1/2 kernels

* moved some functions into private block

* CI fix

* CI fix

* CI fix
---
 .../framework/ir/graph_pattern_detector.cc    |  29 +-
 .../operators/mkldnn/reshape_mkldnn_op.cc     | 290 ++++++++++++++++++
 paddle/fluid/operators/reshape_op.cc          |  45 ++-
 paddle/fluid/operators/squeeze_op.cc          |  65 +++-
 .../mkldnn/test_reshape_mkldnn_op.py          | 217 +++++++++++++
 .../mkldnn/test_squeeze2_mkldnn_op.py         | 162 ++++++++++
 6 files changed, 770 insertions(+), 38 deletions(-)
 create mode 100644 paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc
 mode change 100755 => 100644 paddle/fluid/operators/squeeze_op.cc
 create mode 100644 python/paddle/fluid/tests/unittests/mkldnn/test_reshape_mkldnn_op.py
 create mode 100644 python/paddle/fluid/tests/unittests/mkldnn/test_squeeze2_mkldnn_op.py

diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 9d06a4de9548d..70e48755dcd1e 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2262,26 +2262,15 @@ PDNode *patterns::QuantizePlacement::operator()(
 PDNode *patterns::Bfloat16Placement::operator()(
     const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
   std::unordered_set<std::string> supported_op_types =
-      std::unordered_set<std::string>({"concat",
-                                       "conv2d",
-                                       "conv2d_transpose",
-                                       "elementwise_add",
-                                       "elementwise_mul",
-                                       "fc",
-                                       "fusion_gru",
-                                       "fusion_lstm",
-                                       "gelu",
-                                       "layer_norm",
-                                       "matmul",
-                                       "matmul_v2",
-                                       "pool2d",
-                                       "prelu",
-                                       "relu",
-                                       "reshape2",
-                                       "softmax",
-                                       "split",
-                                       "sum",
-                                       "transpose2"});
+      std::unordered_set<std::string>(
+          {"concat", "conv2d", "conv2d_transpose",
+           "elementwise_add", "elementwise_mul", "fc",
+           "fusion_gru", "fusion_lstm", "gelu",
+           "layer_norm", "matmul", "matmul_v2",
+           "pool2d", "prelu", "relu",
+           "reshape2", "softmax", "split",
+           "squeeze", "squeeze2", "sum",
+           "transpose2"});
   if (!bfloat16_enabled_op_types.empty()) {
     supported_op_types = bfloat16_enabled_op_types;
   }
diff --git a/paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc
new file mode 100644
index 0000000000000..244430e69f234
--- /dev/null
+++ b/paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc
@@ -0,0 +1,290 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/squeeze_op.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
+
+namespace paddle {
+namespace operators {
+
+using paddle::framework::LoDTensor;
+using platform::to_void_cast;
+using platform::GetMKLDNNFormat;
+
+template <typename T>
+class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    RunKernel(ctx);
+  }
+
+ private:
+  void RunKernel(const framework::ExecutionContext& ctx) const {
+    const auto& dev_ctx =
+        ctx.template device_context<platform::MKLDNNDeviceContext>();
+    const auto& onednn_engine = dev_ctx.GetEngine();
+
+    auto* x = ctx.Input<LoDTensor>("X");
+    auto* xshape = ctx.Output<LoDTensor>("XShape");
+    auto* out = ctx.Output<LoDTensor>("Out");
+
+    framework::DDim x_dims;
+    // if the op is reshape or squeeze (the variants without an XShape output)
+    if (ctx.Type().find("2") == std::string::npos) {
+      x_dims = x->dims();
+    } else {
+      auto xshape_dims = xshape->dims();
+      x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
+    }
+
+    auto x_vec_dims = framework::vectorize(x_dims);
+
+    framework::DDim out_dims;
+    if (ctx.Type() == "squeeze") {
+      auto& axes = ctx.Attr<std::vector<int>>("axes");
+      out_dims = GetOutputShape(axes, x_dims, true);
+    } else {
+      out_dims = out->dims();
+    }
+
+    if (ctx.Type().find("reshape") != std::string::npos) {
+      if (ctx.HasInput("Shape")) {
+        auto* shape_tensor = ctx.Input<LoDTensor>("Shape");
+        auto* shape_data = shape_tensor->data<int>();
+
+        auto shape =
+            std::vector<int>(shape_data, shape_data + shape_tensor->numel());
+        out_dims = ValidateShape(shape, x_dims);
+      }
+    }
+
+    mkldnn::memory::data_type x_type = framework::ToMKLDNNDataType(x->type());
+    std::string key =
+        platform::CreateKey(dev_ctx, x_vec_dims, x->format(), x_type);
+    platform::ReorderMKLDNNHandler reorder_handler(
+        x_vec_dims, x->type(), x_type, dev_ctx, onednn_engine, key);
+
+    auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
+        x->format(), platform::to_void_cast(x->data<T>()));
+    out->Resize(x_dims);  // to match x numel, format is changed later
+    // reorder is done into a plain tag to allow usage with blocked formats
+    auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
+        out, getPlainFormatTag(x), ctx.GetPlace());
+    auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p,
+                                                    reorder_dst_memory_p);
+
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
+    reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
+
+    astream.wait();
+
+    out->Resize(out_dims);
+    out->set_layout(framework::DataLayout::kMKLDNN);
+    out->set_format(GetMKLDNNFormat(reorder_dst_memory_p->get_desc().reshape(
+        framework::vectorize(out_dims))));
+  }
+
+ protected:
+  static mkldnn::memory::format_tag getPlainFormatTag(const Tensor* tensor) {
+    auto tensor_dims_size = tensor->dims().size();
+    PADDLE_ENFORCE_EQ(
+        tensor_dims_size <= 6 && tensor_dims_size >= 1, true,
+        platform::errors::InvalidArgument(
+            "Dims for reshape/squeeze oneDNN ops must be in range <1, 6>"));
+
+    switch (tensor_dims_size) {
+      case 1:
+        return mkldnn::memory::format_tag::a;
+      case 2:
+        return mkldnn::memory::format_tag::ab;
+      case 3:
+        return mkldnn::memory::format_tag::abc;
+      case 4:
+        return mkldnn::memory::format_tag::abcd;
+      case 5:
+        return mkldnn::memory::format_tag::abcde;
+      default:
+        return mkldnn::memory::format_tag::abcdef;
+    }
+  }
+
+  static framework::DDim ValidateShape(const std::vector<int>& shape,
+                                       const framework::DDim& in_dims) {
+    const int64_t in_size = framework::product(in_dims);
+    auto in_dims_vec = framework::vectorize(in_dims);
+    bool all_positive = std::all_of(in_dims_vec.cbegin(), in_dims_vec.cend(),
+                                    [](int64_t i) { return i > 0; });
+    // only one dimension can be set to -1, whose size will be automatically
+    // inferred
+    const int64_t unk_dim_val = -1;
+    const int64_t copy_dim_val = 0;
+
+    std::vector<int64_t> output_shape(shape.size(), 0);
+    int64_t capacity = 1;
+    int unk_dim_idx = -1;
+    for (size_t i = 0; i < shape.size(); ++i) {
+      if (shape[i] == unk_dim_val) {
+        PADDLE_ENFORCE_EQ(
+            unk_dim_idx, -1,
+            platform::errors::InvalidArgument(
+                "Only one dimension value of 'shape' in ReshapeOp can "
+                "be -1. But received shape = [%s], shape[%d] is also -1.",
+                framework::make_ddim(shape), i));
+        unk_dim_idx = i;
+      } else if (shape[i] == copy_dim_val) {
+        PADDLE_ENFORCE_LT(
+            static_cast<int>(i), in_dims.size(),
+            platform::errors::InvalidArgument(
+                "The index of 0 in `shape` must be less than "
+                "the input tensor X's dimensions. "
+                "But received shape = [%s], shape[%d] = 0, X's shape = [%s], "
+                "X's dimensions = %d.",
+                framework::make_ddim(shape), i, in_dims, in_dims.size()));
+      } else {
+        PADDLE_ENFORCE_GT(
+            shape[i], 0,
+            platform::errors::InvalidArgument(
+                "Each dimension value of 'shape' in ReshapeOp must not "
+                "be negative except one unknown dimension. "
+                "But received shape = [%s], shape[%d] = %d.",
+                framework::make_ddim(shape), i, shape[i]));
+      }
+
+      capacity *= (shape[i] ? shape[i] : in_dims[i]);
+      output_shape[i] =
+          (shape[i] ? static_cast<int64_t>(shape[i]) : in_dims[i]);
+    }
+
+    if (unk_dim_idx != -1) {
+      if (all_positive) {
+        // in_size < 0 means the shape is indeterminate at compile time, so
+        // the check below is skipped in that case. For example, in_dims =
+        // [-1, 8, 1, 1] and shape = [-1, 3, 8] give capacity = -24,
+        // in_size = -8, output_shape[0] = 0, and the check would fail.
+        output_shape[unk_dim_idx] = -in_size / capacity;
+        PADDLE_ENFORCE_EQ(
+            output_shape[unk_dim_idx] * capacity, -in_size,
+            platform::errors::InvalidArgument(
+                "The 'shape' attribute in ReshapeOp is invalid. "
+                "The input tensor X's size must be divisible by the known "
+                "capacity of 'shape'. "
+                "But received X's shape = [%s], X's size = %d, "
+                "'shape' is [%s], known capacity of 'shape' is %d.",
+                in_dims, in_size, framework::make_ddim(shape), capacity));
+      } else {
+        output_shape[unk_dim_idx] = -1;
+      }
+    } else {
+      if (all_positive) {
+        PADDLE_ENFORCE_EQ(
+            capacity, in_size,
+            platform::errors::InvalidArgument(
+                "The 'shape' in ReshapeOp is invalid. "
+                "The input tensor X's size must be equal to the capacity of "
" + "But received X's shape = [%s], X's size = %d, 'shape' is " + "[%s], the capacity of 'shape' is %d.", + in_dims, in_size, framework::make_ddim(shape), capacity)); + } + } + return framework::make_ddim(output_shape); + } +}; + +template +class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + RunKernel(ctx); + } + + private: + void RunKernel(const framework::ExecutionContext& ctx) const { + const auto& dev_ctx = + ctx.template device_context(); + const auto& onednn_engine = dev_ctx.GetEngine(); + + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + + framework::DDim x_dims; + // if reshape or squeeze + if (ctx.Type().find("2") == std::string::npos) { + x_dims = dx->dims(); + } else { + auto xshape_dims = ctx.Input("XShape")->dims(); + x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size()); + } + auto dout_vec_dims = framework::vectorize(dout->dims()); + + mkldnn::memory::data_type dout_type = + framework::ToMKLDNNDataType(dout->type()); + std::string key = + platform::CreateKey(dev_ctx, dout_vec_dims, this->getPlainFormatTag(dx), + dx->format(), dout_type); + platform::ReorderMKLDNNHandler reorder_handler( + dout_vec_dims, dout->type(), dout_type, dev_ctx, onednn_engine, key); + + auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( + dout->format(), platform::to_void_cast(dout->data())); + auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( + dx, this->getPlainFormatTag(dout), ctx.GetPlace()); + auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p, + reorder_dst_memory_p); + + auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); + reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); + astream.wait(); + + dx->Resize(x_dims); + dx->set_layout(framework::DataLayout::kMKLDNN); + dx->set_format(GetMKLDNNFormat(reorder_dst_memory_p->get_desc().reshape( + framework::vectorize(x_dims)))); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_KERNEL(squeeze, MKLDNN, paddle::platform::CPUPlace, + ops::ReshapeMKLDNNKernel, + ops::ReshapeMKLDNNKernel); + +REGISTER_OP_KERNEL(squeeze_grad, MKLDNN, paddle::platform::CPUPlace, + ops::ReshapeGradMKLDNNKernel, + ops::ReshapeGradMKLDNNKernel); + +REGISTER_OP_KERNEL(squeeze2, MKLDNN, paddle::platform::CPUPlace, + ops::ReshapeMKLDNNKernel, + ops::ReshapeMKLDNNKernel); + +REGISTER_OP_KERNEL(squeeze2_grad, MKLDNN, paddle::platform::CPUPlace, + ops::ReshapeGradMKLDNNKernel, + ops::ReshapeGradMKLDNNKernel); + +REGISTER_OP_KERNEL(reshape, MKLDNN, paddle::platform::CPUPlace, + ops::ReshapeMKLDNNKernel, + ops::ReshapeMKLDNNKernel); + +REGISTER_OP_KERNEL(reshape_grad, MKLDNN, paddle::platform::CPUPlace, + ops::ReshapeGradMKLDNNKernel, + ops::ReshapeGradMKLDNNKernel); + +REGISTER_OP_KERNEL(reshape2, MKLDNN, paddle::platform::CPUPlace, + ops::ReshapeMKLDNNKernel, + ops::ReshapeMKLDNNKernel); + +REGISTER_OP_KERNEL(reshape2_grad, MKLDNN, paddle::platform::CPUPlace, + ops::ReshapeGradMKLDNNKernel, + ops::ReshapeGradMKLDNNKernel); diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 717029cb8f117..c9c1750b8569f 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -228,9 +228,17 @@ class ReshapeOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const 
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    auto input_data_type =
+        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
+
+#ifdef PADDLE_WITH_MKLDNN
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
+    }
+#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 
   framework::OpKernelType GetKernelTypeForVar(
@@ -269,6 +277,9 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
         "It has the lowest priority compare with Input(Shape) and "
         " Input(ShapeTensor).")
         .SetDefault({});
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
     AddComment(R"DOC(
 Reshape Operator.
 
@@ -334,9 +345,17 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    auto input_data_type =
+        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
+
+#ifdef PADDLE_WITH_MKLDNN
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
+    }
+#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 };
 
@@ -517,9 +536,17 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
+    auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
+        ctx, framework::GradVarName("Out"));
+
+#ifdef PADDLE_WITH_MKLDNN
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
+    }
+#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 
   framework::OpKernelType GetKernelTypeForVar(
diff --git a/paddle/fluid/operators/squeeze_op.cc b/paddle/fluid/operators/squeeze_op.cc
old mode 100755
new mode 100644
index ff4ec2f532474..866431951f051
--- a/paddle/fluid/operators/squeeze_op.cc
+++ b/paddle/fluid/operators/squeeze_op.cc
@@ -110,9 +110,17 @@ class SqueezeOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    auto input_data_type =
+        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
+
+#ifdef PADDLE_WITH_MKLDNN
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
+    }
+#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 };
 
@@ -129,9 +137,17 @@ class SqueezeGradOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
+    auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
+        ctx, framework::GradVarName("Out"));
+
+#ifdef PADDLE_WITH_MKLDNN
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
+    }
+#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 };
 
@@ -144,6 +160,14 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
         "(std::vector<int>). List of integers,"
         " indicating the dimensions to squeeze.")
         .SetDefault({});
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
+    AddAttr<std::string>(
+        "mkldnn_data_type",
+        "(string, default \"float32\"). Data type of mkldnn kernel")
+        .SetDefault("float32")
+        .InEnum({"float32", "bfloat16"});
     AddComment(R"DOC(
 Squeeze Operator.
 
@@ -209,6 +233,21 @@ class Squeeze2Op : public framework::OperatorWithKernel {
     ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
     ctx->ShareLoD("X", /*->*/ "XShape");
   }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    auto input_data_type =
+        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
+
+#ifdef PADDLE_WITH_MKLDNN
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
+    }
+#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
 };
 
 template <typename T>
@@ -243,9 +282,17 @@ class Squeeze2GradOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
+    auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
+        ctx, framework::GradVarName("Out"));
+
+#ifdef PADDLE_WITH_MKLDNN
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
+    }
+#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 };
 
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reshape_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reshape_mkldnn_op.py
new file mode 100644
index 0000000000000..1389421586638
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reshape_mkldnn_op.py
@@ -0,0 +1,217 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import paddle
+import paddle.fluid.core as core
+from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16
+
+
+@OpTestTool.skip_if(core.is_compiled_with_cuda(),
+                    "CUDA has to be skipped because it forces dygraph")
+class TestReshape2OneDNNOp(OpTest):
+    def setUp(self):
+        self.init_data()
+        self.set_op_type()
+        self.x = np.random.random(self.ori_shape).astype("float32")
+        self.set_inputs()
+        self.set_additional_inputs()
+        self.set_attrs()
+        self.set_outputs()
+
+    def set_op_type(self):
+        self.op_type = "reshape2"
+
+    def set_inputs(self):
+        self.inputs = {"X": self.x}
+
+    def set_additional_inputs(self):
+        pass
+
+    def set_attrs(self):
+        self.attrs = {"shape": self.new_shape, 'use_mkldnn': True}
+
+    def set_outputs(self):
+        self.outputs = {
+            "Out": self.inputs["X"].reshape(self.infered_shape),
+            'XShape': np.random.random(self.ori_shape).astype("float32")
+        }
+
+    def init_data(self):
+        self.ori_shape = (2, 60)
+        self.new_shape = (12, 10)
+        self.infered_shape = (12, 10)
+
+    def test_check_output(self):
+        self.check_output(no_check_set=['XShape'])
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
+
+
+class TestReshape2OneDNNOpDimInfer1(TestReshape2OneDNNOp):
+    def init_data(self):
+        self.ori_shape = (5, 25)
+        self.new_shape = (5, -1, 5)
+        self.infered_shape = (5, -1, 5)
+
+
+class TestReshape2OneDNNOpDimInfer2(TestReshape2OneDNNOp):
+    def init_data(self):
+        self.ori_shape = (10, 2, 6)
+        self.new_shape = (10, 0, 3, -1)
+        self.infered_shape = (10, 2, 3, -1)
+
+    def set_additional_inputs(self):
+        self.inputs["Shape"] = np.array(self.actual_shape, dtype="int32")
+
+    def set_outputs(self):
+        self.outputs = {
+            "Out": self.inputs["X"].reshape(self.actual_shape),
+            'XShape': np.random.random(self.ori_shape).astype("float32")
+        }
+
+    def init_data(self):
+        self.ori_shape = (6, 20)
+        self.new_shape = (0, -1, 20)
+        self.actual_shape = (2, 3, 20)
+
+
+class TestReshape2OneDNNOp_attr_OnlyShape(TestReshape2OneDNNOp):
+    def set_additional_inputs(self):
+        self.inputs["Shape"] = np.array(self.new_shape, dtype="int32")
+
+    def set_attrs(self):
+        self.attrs = {'use_mkldnn': True}
+
+    def set_outputs(self):
+        self.outputs = {
+            "Out": self.inputs["X"].reshape(self.infered_shape),
+            'XShape': np.random.random(self.ori_shape).astype("float32")
+        }
+
+    def init_data(self):
+        self.ori_shape = (4, 25)
+        self.new_shape = (10, 10)
+        self.infered_shape = (10, 10)
+
+
+class TestReshape2OneDNNOpDimInfer1_attr_OnlyShape(
+        TestReshape2OneDNNOp_attr_OnlyShape):
+    def init_data(self):
+        self.ori_shape = (5, 20)
+        self.new_shape = (5, -1, 10)
+        self.infered_shape = (5, -1, 10)
+        self.shape = (5, -1, -1)
+
+
+class TestReshapeOneDNNOp(TestReshape2OneDNNOp):
+    def set_op_type(self):
+        self.op_type = "reshape"
+
+    def set_outputs(self):
+        self.outputs = {"Out": self.inputs["X"].reshape(self.infered_shape)}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestReshapeOneDNNOpDimInfer1(TestReshapeOneDNNOp):
+    def init_data(self):
+        self.ori_shape = (5, 25)
+        self.new_shape = (5, -1, 5)
+        self.infered_shape = (5, -1, 5)
+
+
+class TestReshapeOneDNNOp_attr_OnlyShape(TestReshape2OneDNNOp_attr_OnlyShape):
+    def set_op_type(self):
+        self.op_type = "reshape"
+
+    def set_outputs(self):
+        self.outputs = {"Out": self.inputs["X"].reshape(self.infered_shape)}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestReshapeOneDNNOpDimInfer1_attr_OnlyShape(
+        TestReshapeOneDNNOp_attr_OnlyShape):
+    def init_data(self):
+        self.ori_shape = (5, 20)
+        self.new_shape = (5, -1, 10)
+        self.infered_shape = (5, -1, 10)
+        self.shape = (5, -1, -1)
+
+
+# BF16 TESTS
+def create_reshape_bf16_test_classes(parent):
+    @OpTestTool.skip_if_not_cpu_bf16()
+    class TestReshape2BF16OneDNNOp(parent):
+        def set_inputs(self):
+            self.dtype = np.uint16
+            self.inputs = {"X": convert_float_to_uint16(self.x)}
+
+        def calculate_grads(self):
+            self.dout = self.outputs['Out']
+            self.dx = np.reshape(self.dout, self.ori_shape)
+
+        def test_check_output(self):
+            self.check_output_with_place(
+                core.CPUPlace(), no_check_set=["XShape"])
+
+        def test_check_grad(self):
+            self.calculate_grads()
+            self.check_grad_with_place(
+                core.CPUPlace(), ["X"],
+                "Out",
+                user_defined_grads=[self.dx],
+                user_defined_grad_outputs=[self.dout])
+
+    cls_name = "{0}_{1}".format(parent.__name__, "Reshape2_BF16")
+    TestReshape2BF16OneDNNOp.__name__ = cls_name
+    globals()[cls_name] = TestReshape2BF16OneDNNOp
+
+    class TestReshapeBF16OneDNNOp(TestReshape2BF16OneDNNOp):
+        def set_op_type(self):
+            self.dtype = np.uint16
+            self.op_type = "reshape"
+
+        def set_outputs(self):
+            self.outputs = {"Out": self.x.reshape(self.new_shape)}
+
+        def test_check_output(self):
+            self.check_output_with_place(core.CPUPlace())
+
+        def test_check_grad(self):
+            self.calculate_grads()
+            self.check_grad_with_place(
+                core.CPUPlace(), ["X"],
+                "Out",
+                user_defined_grads=[self.dx],
+                user_defined_grad_outputs=[convert_float_to_uint16(self.dout)])
+
+    cls_name = "{0}_{1}".format(parent.__name__, "Reshape_BF16")
+    TestReshapeBF16OneDNNOp.__name__ = cls_name
+    globals()[cls_name] = TestReshapeBF16OneDNNOp
+
+
+create_reshape_bf16_test_classes(TestReshape2OneDNNOp)
+create_reshape_bf16_test_classes(TestReshape2OneDNNOpDimInfer1)
+
+if __name__ == "__main__":
+    paddle.enable_static()
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_squeeze2_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_squeeze2_mkldnn_op.py
new file mode 100644
index 0000000000000..489d851038042
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_squeeze2_mkldnn_op.py
@@ -0,0 +1,162 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import paddle
+import paddle.fluid.core as core
+from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16
+
+
+@OpTestTool.skip_if(core.is_compiled_with_cuda(),
+                    "CUDA has to be skipped because it forces dygraph")
+class TestSqueeze2OneDNNOp(OpTest):
+    def set_op_type(self):
+        self.op_type = "squeeze2"
+
+    def init_test_case(self):
+        self.ori_shape = (1, 3, 1, 40)
+        self.axes = (0, 2)
+        self.new_shape = (3, 40)
+
+    def set_inputs(self):
+        self.inputs = {"X": self.x}
+
+    def init_attrs(self):
+        self.attrs = {"axes": self.axes, 'use_mkldnn': True}
+
+    def set_outputs(self):
+        self.outputs = {
+            "Out": self.x.reshape(self.new_shape),
+            "XShape": np.random.random(self.ori_shape).astype("float32")
+        }
+
+    def setUp(self):
+        self.set_op_type()
+        self.init_test_case()
+        self.x = np.random.random(self.ori_shape).astype("float32")
+        self.set_inputs()
+        self.init_attrs()
+        self.set_outputs()
+
+    def test_check_output(self):
+        self.check_output_with_place(core.CPUPlace(), no_check_set=['XShape'])
+
+    def test_check_grad(self):
+        self.check_grad_with_place(core.CPUPlace(), ["X"], "Out")
+
+
+class TestSqueezeOneDNNOp(TestSqueeze2OneDNNOp):
+    def set_op_type(self):
+        self.op_type = "squeeze"
+
+    def set_outputs(self):
+        self.outputs = {"Out": self.x.reshape(self.new_shape)}
+
+    def test_check_output(self):
+        self.check_output_with_place(core.CPUPlace())
+
+
+class TestSqueeze2OneDNNOp1(TestSqueeze2OneDNNOp):
+    def init_test_case(self):
+        self.ori_shape = (1, 20, 1, 5)
+        self.axes = (0, -2)
+        self.new_shape = (20, 5)
+
+
+class TestSqueezeOneDNNOp1(TestSqueezeOneDNNOp):
+    def init_test_case(self):
+        self.ori_shape = (1, 20, 1, 5)
+        self.axes = (0, -2)
+        self.new_shape = (20, 5)
+
+
+class TestSqueeze2OneDNNOp2(TestSqueeze2OneDNNOp):
+    def init_test_case(self):
+        self.ori_shape = (1, 20, 1, 5)
+        self.axes = ()
+        self.new_shape = (20, 5)
+
+
+class TestSqueezeOneDNNOp2(TestSqueezeOneDNNOp):
+    def init_test_case(self):
+        self.ori_shape = (1, 20, 1, 5)
+        self.axes = ()
+        self.new_shape = (20, 5)
+
+
+class TestSqueeze2OneDNNOp3(TestSqueeze2OneDNNOp):
+    def init_test_case(self):
+        self.ori_shape = (25, 1, 1, 4, 1)
+        self.axes = (1, -1)
+        self.new_shape = (25, 1, 4)
+
+
+class TestSqueezeOneDNNOp3(TestSqueezeOneDNNOp):
+    def init_test_case(self):
+        self.ori_shape = (25, 1, 1, 4, 1)
+        self.axes = (1, -1)
+        self.new_shape = (25, 1, 4)
+
+
+# BF16 TESTS
+def create_squeeze_bf16_test_classes(parent):
+    @OpTestTool.skip_if_not_cpu_bf16()
+    class TestSqueeze2BF16OneDNNOp(parent):
+        def set_inputs(self):
+            self.dtype = np.uint16
+            self.inputs = {"X": convert_float_to_uint16(self.x)}
+
+        def calculate_grads(self):
+            self.dout = self.outputs['Out']
+            self.dx = np.reshape(self.dout, self.ori_shape)
+
+        def test_check_grad(self):
+            self.calculate_grads()
+            self.check_grad_with_place(
+                core.CPUPlace(), ["X"],
+                "Out",
+                user_defined_grads=[self.dx],
+                user_defined_grad_outputs=[self.dout])
+
+    cls_name = "{0}_{1}".format(parent.__name__, "Squeeze2_BF16")
+    TestSqueeze2BF16OneDNNOp.__name__ = cls_name
+    globals()[cls_name] = TestSqueeze2BF16OneDNNOp
+
+    class TestSqueezeBF16OneDNNOp(TestSqueeze2BF16OneDNNOp):
+        def set_op_type(self):
+            self.dtype = np.uint16
+            self.op_type = "squeeze"
+
+        def set_outputs(self):
+            self.outputs = {"Out": self.x.reshape(self.new_shape)}
+
+        def test_check_output(self):
+            self.check_output_with_place(core.CPUPlace())
+
"Squeeze_BF16") + TestSqueezeBF16OneDNNOp.__name__ = cls_name + globals()[cls_name] = TestSqueezeBF16OneDNNOp + + +create_squeeze_bf16_test_classes(TestSqueeze2OneDNNOp) +create_squeeze_bf16_test_classes(TestSqueeze2OneDNNOp1) +create_squeeze_bf16_test_classes(TestSqueeze2OneDNNOp2) +create_squeeze_bf16_test_classes(TestSqueeze2OneDNNOp3) + +if __name__ == "__main__": + paddle.enable_static() + unittest.main()