Skip to content

Commit

Permalink
【Hackathon 5th No.13】【关联 PR】Added int support for sign -Part (PaddleP…
Browse files Browse the repository at this point in the history
…addle#58255)

* ♻️ Refactor: added sign int type support

* ✏️ Refactor: update typo
  • Loading branch information
PommesPeter authored Oct 20, 2023
1 parent 7b74967 commit 2195862
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 41 deletions.
11 changes: 10 additions & 1 deletion paddle/phi/kernels/cpu/sign_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,13 @@ limitations under the License. */
// See Note [ Why still include the fluid headers? ]
#include "paddle/phi/common/bfloat16.h"

PD_REGISTER_KERNEL(sign, CPU, ALL_LAYOUT, phi::SignKernel, float, double) {}
PD_REGISTER_KERNEL(sign,
CPU,
ALL_LAYOUT,
phi::SignKernel,
int8_t,
int16_t,
int32_t,
int64_t,
float,
double) {}
4 changes: 4 additions & 0 deletions paddle/phi/kernels/funcs/eigen/sign.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ struct EigenSign<Eigen::DefaultDevice, T> {
}
};

template struct EigenSign<Eigen::DefaultDevice, int8_t>;
template struct EigenSign<Eigen::DefaultDevice, int16_t>;
template struct EigenSign<Eigen::DefaultDevice, int32_t>;
template struct EigenSign<Eigen::DefaultDevice, int64_t>;
template struct EigenSign<Eigen::DefaultDevice, float>;
template struct EigenSign<Eigen::DefaultDevice, double>;

Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/kernels/funcs/eigen/sign.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ struct EigenSign<Eigen::GpuDevice, T> {
}
};

template struct EigenSign<Eigen::GpuDevice, int8_t>;
template struct EigenSign<Eigen::GpuDevice, int16_t>;
template struct EigenSign<Eigen::GpuDevice, int32_t>;
template struct EigenSign<Eigen::GpuDevice, int64_t>;
template struct EigenSign<Eigen::GpuDevice, float>;
template struct EigenSign<Eigen::GpuDevice, double>;
template struct EigenSign<Eigen::GpuDevice, dtype::float16>;
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/kernels/gpu/sign_kernel.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ PD_REGISTER_KERNEL(sign,
GPU,
ALL_LAYOUT,
phi::SignKernel,
int8_t,
int16_t,
int32_t,
int64_t,
float,
double,
phi::dtype::float16,
Expand Down
16 changes: 14 additions & 2 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -4565,7 +4565,7 @@ def sign(x, name=None):
Returns sign of every element in `x`: 1 for positive, -1 for negative and 0 for zero.
Args:
x (Tensor): The input tensor. The data type can be float16, float32 or float64.
x (Tensor): The input tensor. The data type can be int8, int16, int32, int64, float16, float32 or float64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Expand All @@ -4586,7 +4586,19 @@ def sign(x, name=None):
return _C_ops.sign(x)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'uint16'], 'sign'
x,
'x',
[
'int8',
'int16',
'int32',
'int64',
'float16',
'float32',
'float64',
'uint16',
],
'sign',
)
helper = LayerHelper("sign", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
Expand Down
88 changes: 50 additions & 38 deletions test/legacy_test/test_sign_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,28 +76,12 @@ def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Out')


class TestSignOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# The input type of sign_op must be Variable or numpy.ndarray.
input1 = 12
self.assertRaises(TypeError, paddle.sign, input1)
# The input dtype of sign_op must be float16, float32, float64.
input2 = paddle.static.data(
name='input2', shape=[-1, 12, 10], dtype="int32"
)
input3 = paddle.static.data(
name='input3', shape=[-1, 12, 10], dtype="int64"
)
self.assertRaises(TypeError, paddle.sign, input2)
self.assertRaises(TypeError, paddle.sign, input3)
input4 = paddle.static.data(
name='input4', shape=[-1, 4], dtype="float16"
)
paddle.sign(input4)


class TestSignAPI(unittest.TestCase):
def setUp(self):
self.place = [base.CPUPlace()]
if core.is_compiled_with_cuda():
self.place.append(base.CUDAPlace(0))

def test_dygraph(self):
with base.dygraph.guard():
np_x = np.array([-1.0, 0.0, -0.0, 1.2, 1.5], dtype='float64')
Expand All @@ -108,23 +92,51 @@ def test_dygraph(self):
self.assertEqual((np_z == z_expected).all(), True)

def test_static(self):
with program_guard(Program(), Program()):
# The input type of sign_op must be Variable or numpy.ndarray.
input1 = 12
self.assertRaises(TypeError, paddle.tensor.math.sign, input1)
# The input dtype of sign_op must be float16, float32, float64.
input2 = paddle.static.data(
name='input2', shape=[-1, 12, 10], dtype="int32"
)
input3 = paddle.static.data(
name='input3', shape=[-1, 12, 10], dtype="int64"
)
self.assertRaises(TypeError, paddle.tensor.math.sign, input2)
self.assertRaises(TypeError, paddle.tensor.math.sign, input3)
input4 = paddle.static.data(
name='input4', shape=[-1, 4], dtype="float16"
)
paddle.sign(input4)
np_input2 = np.random.uniform(-10, 10, (12, 10)).astype("int16")
np_input3 = np.random.uniform(-10, 10, (12, 10)).astype("int32")
np_input4 = np.random.uniform(-10, 10, (12, 10)).astype("int64")
np_out2 = np.sign(np_input2)
np_out3 = np.sign(np_input3)
np_out4 = np.sign(np_input4)

def run(place):
with program_guard(Program(), Program()):
# The input type of sign_op must be Variable or numpy.ndarray.
input1 = 12
self.assertRaises(TypeError, paddle.tensor.math.sign, input1)
# The result of sign_op must correct.
input2 = paddle.static.data(
name='input2', shape=[12, 10], dtype="int16"
)
input3 = paddle.static.data(
name='input3', shape=[12, 10], dtype="int32"
)
input4 = paddle.static.data(
name='input4', shape=[12, 10], dtype="int64"
)
out2 = paddle.sign(input2)
out3 = paddle.sign(input3)
out4 = paddle.sign(input4)
exe = paddle.static.Executor(place)
res2, res3, res4 = exe.run(
paddle.static.default_main_program(),
feed={
"input2": np_input2,
"input3": np_input3,
"input4": np_input4,
},
fetch_list=[out2, out3, out4],
)
self.assertEqual((res2 == np_out2).all(), True)
self.assertEqual((res3 == np_out3).all(), True)
self.assertEqual((res4 == np_out4).all(), True)
input5 = paddle.static.data(
name='input5', shape=[-1, 4], dtype="float16"
)
paddle.sign(input5)

for place in self.place:
run(place)


class TestSignDoubleGradCheck(unittest.TestCase):
Expand Down

0 comments on commit 2195862

Please sign in to comment.