Skip to content

Commit

Permalink
multiply supports bool
Browse files Browse the repository at this point in the history
multiply supports bool
  • Loading branch information
will-jl944 authored Sep 8, 2021
1 parent a2dbb0c commit db5fd2a
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 4 deletions.
4 changes: 4 additions & 0 deletions paddle/fluid/operators/elementwise/elementwise_mul_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, bool>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
Expand All @@ -142,6 +143,7 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
Expand All @@ -156,6 +158,8 @@ REGISTER_OP_CPU_KERNEL(
int>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
bool>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/operators/elementwise/elementwise_mul_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseMulKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, bool>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<float>>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<double>>);
Expand All @@ -130,6 +131,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, bool>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext,
plat::complex<float>>,
Expand All @@ -141,6 +143,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, bool>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
plat::complex<float>>,
Expand Down
12 changes: 12 additions & 0 deletions python/paddle/fluid/tests/unittests/test_multiply.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ def test_multiply(self):
res = self._run_static_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.multiply(x_data, y_data)))

# test static computation graph: boolean
x_data = np.random.choice([True, False], size=[200])
y_data = np.random.choice([True, False], size=[200])
res = self._run_static_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.multiply(x_data, y_data)))

# test dynamic computation graph: 1-d array
x_data = np.random.rand(200)
y_data = np.random.rand(200)
Expand All @@ -88,6 +94,12 @@ def test_multiply(self):
res = self._run_dynamic_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.multiply(x_data, y_data)))

# test dynamic computation graph: boolean
x_data = np.random.choice([True, False], size=[200])
y_data = np.random.choice([True, False], size=[200])
res = self._run_dynamic_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.multiply(x_data, y_data)))


class TestMultiplyError(unittest.TestCase):
def test_errors(self):
Expand Down
8 changes: 4 additions & 4 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,10 @@ def _elementwise_op(helper):
assert x is not None, 'x cannot be None in {}'.format(original_op_type)
assert y is not None, 'y cannot be None in {}'.format(original_op_type)
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'],
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
original_op_type)
check_variable_and_dtype(
y, 'y', ['float16', 'float32', 'float64', 'int32', 'int64'],
y, 'y', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
original_op_type)

axis = helper.kwargs.get('axis', -1)
Expand Down Expand Up @@ -473,8 +473,8 @@ def multiply(x, y, name=None):
``paddle.multiply`` supports broadcasting. If you would like to know more about broadcasting, please refer to :ref:`user_guide_broadcasting` .
Args:
x (Tensor): the input tensor, its data type should be float32, float64, int32, int64.
y (Tensor): the input tensor, its data type should be float32, float64, int32, int64.
x (Tensor): the input tensor, its data type should be one of float32, float64, int32, int64, bool.
y (Tensor): the input tensor, its data type should be one of float32, float64, int32, int64, bool.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Expand Down

0 comments on commit db5fd2a

Please sign in to comment.