From a9f76d0715daa674f9cfe06b13446b7af37d5d93 Mon Sep 17 00:00:00 2001
From: Chenxiao Niu
Date: Thu, 28 Jul 2022 13:53:50 +0800
Subject: [PATCH] [MLU] fix log_softmax mode selection. (#44669)

---
 paddle/fluid/operators/softmax_op_mlu.cc      |  7 ++--
 .../unittests/mlu/test_log_softmax_op_mlu.py  | 35 +++++++++++++++++++
 2 files changed, 38 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/operators/softmax_op_mlu.cc b/paddle/fluid/operators/softmax_op_mlu.cc
index b2e7410136c13..50ef6c6599294 100644
--- a/paddle/fluid/operators/softmax_op_mlu.cc
+++ b/paddle/fluid/operators/softmax_op_mlu.cc
@@ -117,10 +117,9 @@ REGISTER_OP_MLU_KERNEL(softmax_grad,
                        ops::SoftmaxGradMLUKernel,
                        ops::SoftmaxGradMLUKernel);
 
-REGISTER_OP_MLU_KERNEL(
-    log_softmax,
-    ops::SoftmaxMLUKernel,
-    ops::SoftmaxMLUKernel);
+REGISTER_OP_MLU_KERNEL(log_softmax,
+                       ops::SoftmaxMLUKernel,
+                       ops::SoftmaxMLUKernel);
 REGISTER_OP_MLU_KERNEL(
     log_softmax_grad,
     ops::SoftmaxGradMLUKernel,
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_log_softmax_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_log_softmax_op_mlu.py
index a1d594b93d01d..1b81455f47797 100644
--- a/python/paddle/fluid/tests/unittests/mlu/test_log_softmax_op_mlu.py
+++ b/python/paddle/fluid/tests/unittests/mlu/test_log_softmax_op_mlu.py
@@ -86,6 +86,41 @@ def set_attrs(self):
         self.axis = 1
 
 
+class TestLogSoftmaxOpFp16(OpTest):
+
+    def setUp(self):
+        self.op_type = 'log_softmax'
+        self.set_mlu()
+        self.python_api = F.log_softmax
+        self.dtype = 'float16'
+        self.shape = [2, 3, 4, 5]
+        self.axis = -1
+        self.set_attrs()
+
+        x = np.random.uniform(0.1, 1., self.shape).astype(self.dtype)
+        out = np.apply_along_axis(ref_log_softmax, self.axis, x)
+        self.x_grad = ref_log_softmax_grad(x, self.axis)
+
+        self.inputs = {'X': x}
+        self.outputs = {'Out': out}
+        self.attrs = {'axis': self.axis}
+
+    def set_attrs(self):
+        pass
+
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+        self.place = paddle.device.MLUPlace(0)
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, atol=1e-2)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(self.place, ['X'], ['Out'],
+                                   user_defined_grads=[self.x_grad],
+                                   max_relative_error=0.015)
+
+
 class TestNNLogSoftmaxAPI(unittest.TestCase):
 
     def setUp(self):
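Notes:

The C++ hunk carries the actual fix: per the subject line, the log_softmax registration presumably switches the CNNL algorithm that SoftmaxMLUKernel is instantiated with from the plain softmax mode to the log-softmax mode (CNNL_SOFTMAX_LOG in cnnlSoftmaxAlgorithm_t terms), so the MLU op returns log(softmax(x)) rather than softmax(x). The Python hunk adds an FP16 OpTest that exercises the corrected registration.

The new test builds its expected output and gradient from ref_log_softmax and ref_log_softmax_grad, which are defined earlier in test_log_softmax_op_mlu.py and are not shown in this hunk. A minimal, self-contained sketch of what those helpers presumably compute (the standard max-shifted log-softmax, and its gradient under the uniform upstream gradient that OpTest pairs with user_defined_grads) is:

import numpy as np

def ref_log_softmax(x):
    # Presumed reference: operates on a 1-D slice handed in by
    # np.apply_along_axis; the max shift keeps exp() from overflowing.
    shiftx = x - np.max(x)
    return shiftx - np.log(np.exp(shiftx).sum())

def ref_log_softmax_grad(x, axis):
    # Presumed reference gradient: with out = log_softmax(x) and a uniform
    # upstream gradient dout = 1/x.size, dx = dout - softmax(x) * sum(dout).
    if axis < 0:
        axis += len(x.shape)
    out = np.apply_along_axis(ref_log_softmax, axis, x)
    dout = np.full_like(x, fill_value=1.0 / x.size)
    return dout - np.exp(out) * dout.sum(axis=axis, keepdims=True)

if __name__ == '__main__':
    # Mirror the FP16 test's setUp: same shape, dtype and axis.
    x = np.random.uniform(0.1, 1.0, [2, 3, 4, 5]).astype('float16')
    out = np.apply_along_axis(ref_log_softmax, -1, x)
    x_grad = ref_log_softmax_grad(x, -1)
    # exp(log_softmax) should sum to 1 along the softmax axis.
    assert np.allclose(np.exp(out).sum(axis=-1), 1.0, atol=1e-2)
    print(out.shape, x_grad.shape)

The loose tolerances in the test (atol=1e-2 for outputs, max_relative_error=0.015 for gradients) presumably account for float16 rounding in the MLU kernel rather than anything about the reference above.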