diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py index fe4044a9bb4dd..3b9d817522561 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py @@ -253,7 +253,15 @@ def set_dtype_attr(self): def test_check_output(self): self.check_output_with_place(core.CPUPlace()) - def tmttml(self, x, transpose_x, y, transpose_y): + def test_check_grad(self): + self.calculate_grads() + self.check_grad_with_place( + core.CPUPlace(), ["X", "Y"], + "Out", + user_defined_grads=[self.dx, self.dy], + user_defined_grad_outputs=[convert_float_to_uint16(self.dout)]) + + def matmul_grad(self, x, transpose_x, y, transpose_y): x = np.transpose( x, self.shape_transpose_axes[x.ndim]) if transpose_x else x y = np.transpose( @@ -296,19 +304,19 @@ def calculate_grads(self): is_broadcast = x.shape[0:-2] != y.shape[0:-2] if self.attrs['trans_x'] is True and self.attrs['trans_y'] is True: - self.dx = self.tmttml(self.y_fp32, True, dout, True) - self.dy = self.tmttml(dout, True, self.x_fp32, True) + self.dx = self.matmul_grad(self.y_fp32, True, dout, True) + self.dy = self.matmul_grad(dout, True, self.x_fp32, True) elif self.attrs['trans_x'] is True and self.attrs[ 'trans_y'] is False: - self.dx = self.tmttml(self.y_fp32, False, dout, True) - self.dy = self.tmttml(self.x_fp32, False, dout, False) + self.dx = self.matmul_grad(self.y_fp32, False, dout, True) + self.dy = self.matmul_grad(self.x_fp32, False, dout, False) elif self.attrs['trans_x'] is False and self.attrs[ 'trans_y'] is True: - self.dx = self.tmttml(dout, False, self.y_fp32, False) - self.dy = self.tmttml(dout, True, self.x_fp32, False) + self.dx = self.matmul_grad(dout, False, self.y_fp32, False) + self.dy = self.matmul_grad(dout, True, self.x_fp32, False) else: - self.dx = self.tmttml(dout, False, self.y_fp32, True) - self.dy = self.tmttml(self.x_fp32, True, dout, False) + self.dx = self.matmul_grad(dout, False, self.y_fp32, True) + self.dy = self.matmul_grad(self.x_fp32, True, dout, False) if is_broadcast: x_reduce_axis = [] @@ -340,14 +348,6 @@ def calculate_grads(self): self.dout = dout - def test_check_grad(self): - self.calculate_grads() - self.check_grad_with_place( - core.CPUPlace(), ["X", "Y"], - "Out", - user_defined_grads=[self.dx, self.dy], - user_defined_grad_outputs=[convert_float_to_uint16(self.dout)]) - cls_name = "{0}_{1}".format(parent.__name__, "BF16") TestMatMulV2Bf16OneDNNOp.__name__ = cls_name globals()[cls_name] = TestMatMulV2Bf16OneDNNOp