From db5fd2a1925cbfa710c998364818f1d191f088a4 Mon Sep 17 00:00:00 2001 From: will-jl944 Date: Wed, 8 Sep 2021 15:52:42 +0800 Subject: [PATCH] multiply supports bool multiply supports bool --- .../operators/elementwise/elementwise_mul_op.cc | 4 ++++ .../operators/elementwise/elementwise_mul_op.cu | 3 +++ python/paddle/fluid/tests/unittests/test_multiply.py | 12 ++++++++++++ python/paddle/tensor/math.py | 8 ++++---- 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc index 0045f00ecc6c2..21d1ebddbd459 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc @@ -132,6 +132,7 @@ REGISTER_OP_CPU_KERNEL( ops::ElementwiseMulKernel, ops::ElementwiseMulKernel, ops::ElementwiseMulKernel, + ops::ElementwiseMulKernel, ops::ElementwiseMulKernel>, ops::ElementwiseMulKernel, ops::ElementwiseMulGradKernel, ops::ElementwiseMulGradKernel, + ops::ElementwiseMulGradKernel, ops::ElementwiseMulGradKernel>, ops::ElementwiseMulGradKernel, ops::ElementwiseMulDoubleGradKernel, + ops::ElementwiseMulDoubleGradKernel, ops::ElementwiseMulDoubleGradKernel>, ops::ElementwiseMulDoubleGradKernel, ops::ElementwiseMulKernel, ops::ElementwiseMulKernel, + ops::ElementwiseMulKernel, ops::ElementwiseMulKernel, ops::ElementwiseMulKernel>, ops::ElementwiseMulKernel>); @@ -130,6 +131,7 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseMulGradKernel, ops::ElementwiseMulGradKernel, ops::ElementwiseMulGradKernel, + ops::ElementwiseMulGradKernel, ops::ElementwiseMulGradKernel, ops::ElementwiseMulGradKernel>, @@ -141,6 +143,7 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseMulDoubleGradKernel, ops::ElementwiseMulDoubleGradKernel, ops::ElementwiseMulDoubleGradKernel, + ops::ElementwiseMulDoubleGradKernel, ops::ElementwiseMulDoubleGradKernel, ops::ElementwiseMulDoubleGradKernel>, diff --git a/python/paddle/fluid/tests/unittests/test_multiply.py b/python/paddle/fluid/tests/unittests/test_multiply.py index b839272ccf092..3fd6e3f0c865a 100755 --- a/python/paddle/fluid/tests/unittests/test_multiply.py +++ b/python/paddle/fluid/tests/unittests/test_multiply.py @@ -70,6 +70,12 @@ def test_multiply(self): res = self._run_static_graph_case(x_data, y_data) self.assertTrue(np.allclose(res, np.multiply(x_data, y_data))) + # test static computation graph: boolean + x_data = np.random.choice([True, False], size=[200]) + y_data = np.random.choice([True, False], size=[200]) + res = self._run_static_graph_case(x_data, y_data) + self.assertTrue(np.allclose(res, np.multiply(x_data, y_data))) + # test dynamic computation graph: 1-d array x_data = np.random.rand(200) y_data = np.random.rand(200) @@ -88,6 +94,12 @@ def test_multiply(self): res = self._run_dynamic_graph_case(x_data, y_data) self.assertTrue(np.allclose(res, np.multiply(x_data, y_data))) + # test dynamic computation graph: boolean + x_data = np.random.choice([True, False], size=[200]) + y_data = np.random.choice([True, False], size=[200]) + res = self._run_dynamic_graph_case(x_data, y_data) + self.assertTrue(np.allclose(res, np.multiply(x_data, y_data))) + class TestMultiplyError(unittest.TestCase): def test_errors(self): diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index e73c97ee0f0d8..29f3425cb7687 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -196,10 +196,10 @@ def _elementwise_op(helper): assert x is not None, 'x cannot be None in {}'.format(original_op_type) assert y is not None, 'y cannot be None in {}'.format(original_op_type) check_variable_and_dtype( - x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], + x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'], original_op_type) check_variable_and_dtype( - y, 'y', ['float16', 'float32', 'float64', 'int32', 'int64'], + y, 'y', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'], original_op_type) axis = helper.kwargs.get('axis', -1) @@ -473,8 +473,8 @@ def multiply(x, y, name=None): ``paddle.multiply`` supports broadcasting. If you would like to know more about broadcasting, please refer to :ref:`user_guide_broadcasting` . Args: - x (Tensor): the input tensor, its data type should be float32, float64, int32, int64. - y (Tensor): the input tensor, its data type should be float32, float64, int32, int64. + x (Tensor): the input tensor, its data type should be one of float32, float64, int32, int64, bool. + y (Tensor): the input tensor, its data type should be one of float32, float64, int32, int64, bool. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: