diff --git a/paddle/fluid/operators/mkldnn/prelu_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/prelu_mkldnn_op.cc
index d866274aac237..ff19df8eebf1f 100644
--- a/paddle/fluid/operators/mkldnn/prelu_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/prelu_mkldnn_op.cc
@@ -32,7 +32,7 @@ class PReluMKLDNNHandler
   PReluMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
                      const mkldnn::engine engine, platform::Place cpu_place,
                      const Tensor* x, const Tensor* weights,
-                     const std::string& uniq_name)
+                     const std::string& uniq_name, bool is_test)
       : platform::MKLDNNHandlerT<T, dnnl::prelu_forward, dnnl::prelu_backward>(
             dev_ctx, engine, cpu_place,
             platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
@@ -46,8 +46,9 @@ class PReluMKLDNNHandler
       this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
                                               x_md, weights_md);
 
-      this->AcquireBackwardPrimitiveDescriptor(x_md, weights_md, x_md,
-                                               weights_md);
+      if (!is_test)
+        this->AcquireBackwardPrimitiveDescriptor(x_md, weights_md, x_md,
+                                                 weights_md);
     }
   }
 
@@ -87,7 +88,7 @@ class PReluMKLDNNKernel : public framework::OpKernel<T> {
     const bool is_test = ctx.Attr<bool>("is_test");
 
     PReluMKLDNNHandler<T> handler(dev_ctx, onednn_engine, ctx.GetPlace(), x,
-                                  alpha, ctx.InputName("X"));
+                                  alpha, ctx.InputName("X"), is_test);
 
     auto src_memory_p = handler.AcquireSrcMemory(x);
     auto weights_memory_p =
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_prelu_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_prelu_mkldnn_op.py
index ea08487a3f8a1..d5ebc4e274ead 100644
--- a/python/paddle/fluid/tests/unittests/mkldnn/test_prelu_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_prelu_mkldnn_op.py
@@ -132,7 +132,6 @@ def test_check_grad(self):
         self.check_grad_with_place(
             core.CPUPlace(), ["X", "Alpha"],
             "Out",
-            check_dygraph=False,
             user_defined_grads=[self.dx, self.dalpha],
             user_defined_grad_outputs=[convert_float_to_uint16(self.dout)])
 
diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index 4f78eceee4f15..5ac607018856f 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -1436,6 +1436,9 @@ def check_grad_with_place(self,
         op_outputs = self.outputs if hasattr(self, "outputs") else dict()
         op_attrs = self.attrs if hasattr(self, "attrs") else dict()
 
+        if self.is_bfloat16_op():
+            check_dygraph = False
+
         self._check_grad_helper()
         if self.dtype == np.float64 and \
             self.op_type not in op_threshold_white_list.NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST: