Skip to content

Commit

Permalink
update v1.1
Browse files Browse the repository at this point in the history
  • Loading branch information
jjyaoao committed Nov 2, 2023
1 parent 3171348 commit f22b939
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 68 deletions.
22 changes: 3 additions & 19 deletions paddle/phi/kernels/cpu/compare_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,9 @@ PD_REGISTER_KERNEL(equal_all,
ALL_LAYOUT, \
phi::func##Kernel, \
bool, \
int16_t, \
int, \
int8_t, \
int16_t, \
int64_t, \
float, \
double, \
Expand All @@ -120,26 +121,9 @@ PD_REGISTER_KERNEL(equal_all,
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
}

PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan)
PD_REGISTER_COMPARE_KERNEL(greater_equal, GreaterEqual)
PD_REGISTER_COMPARE_KERNEL(equal, Equal)
PD_REGISTER_COMPARE_KERNEL(not_equal, NotEqual)

#define PD_REGISTER_LESS_THAN_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, \
CPU, \
ALL_LAYOUT, \
phi::func##Kernel, \
bool, \
int8_t, \
int16_t, \
int, \
int64_t, \
float, \
double, \
phi::dtype::float16, \
phi::dtype::bfloat16) { \
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
}
PD_REGISTER_LESS_THAN_KERNEL(less_than, LessThan)
23 changes: 3 additions & 20 deletions paddle/phi/kernels/kps/compare_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,9 @@ PD_REGISTER_KERNEL(equal_all,
ALL_LAYOUT, \
phi::func##Kernel, \
bool, \
int16_t, \
int, \
int8_t, \
int16_t, \
int64_t, \
float, \
double, \
Expand All @@ -160,29 +161,11 @@ PD_REGISTER_KERNEL(equal_all,
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
}

PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan)
PD_REGISTER_COMPARE_KERNEL(greater_equal, GreaterEqual)
PD_REGISTER_COMPARE_KERNEL(equal, Equal)
PD_REGISTER_COMPARE_KERNEL(not_equal, NotEqual)

#define PD_REGISTER_LESS_THAN_KERNEL(func) \
PD_REGISTER_KERNEL(less_than, \
KPS, \
ALL_LAYOUT, \
phi::func##Kernel, \
bool, \
int8_t, \
int16_t, \
int, \
int64_t, \
float, \
double, \
phi::dtype::float16, \
phi::dtype::bfloat16) { \
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
}

PD_REGISTER_LESS_THAN_KERNEL(less_than, LessThan)

#endif
3 changes: 2 additions & 1 deletion paddle/phi/kernels/legacy/cpu/compare_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ PD_REGISTER_KERNEL(less_than_raw,
ALL_LAYOUT,
phi::LessThanRawKernel,
bool,
int16_t,
int8_t int16_t,
int,
int64_t,
float,
Expand All @@ -131,6 +131,7 @@ PD_REGISTER_KERNEL(less_than_raw,
ALL_LAYOUT, \
phi::func##RawKernel, \
bool, \
int8_t, \
int16_t, \
int, \
int64_t, \
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/kernels/legacy/kps/compare_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ PD_REGISTER_KERNEL(less_than_raw,
ALL_LAYOUT,
phi::LessThanRawKernel,
bool,
int8_t,
int16_t,
int,
int64_t,
Expand All @@ -157,6 +158,7 @@ PD_REGISTER_KERNEL(less_than_raw,
bool, \
int16_t, \
int, \
int8_t, \
int64_t, \
float, \
double, \
Expand Down
46 changes: 34 additions & 12 deletions python/paddle/tensor/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,8 +512,8 @@ def equal(x, y, name=None):
The output has no gradient.
Args:
x (Tensor): Tensor, data type is bool, float16, float32, float64, int32, int64.
y (Tensor): Tensor, data type is bool, float16, float32, float64, int32, int64.
x (Tensor): Tensor, data type is bool, float16, float32, float64, int8, int16, int32, int64.
y (Tensor): Tensor, data type is bool, float16, float32, float64, int8, int16, int32, int64.
name (str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Expand Down Expand Up @@ -553,6 +553,8 @@ def equal(x, y, name=None):
"float16",
"float32",
"float64",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand All @@ -567,6 +569,8 @@ def equal(x, y, name=None):
"float16",
"float32",
"float64",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand Down Expand Up @@ -611,8 +615,8 @@ def greater_equal(x, y, name=None):
The output has no gradient.
Args:
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64.
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int8, int16, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int8, int16, int32, int64.
name (str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Expand Down Expand Up @@ -641,6 +645,8 @@ def greater_equal(x, y, name=None):
"float16",
"float32",
"float64",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand All @@ -655,6 +661,8 @@ def greater_equal(x, y, name=None):
"float16",
"float32",
"float64",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand Down Expand Up @@ -699,8 +707,8 @@ def greater_than(x, y, name=None):
The output has no gradient.
Args:
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64.
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int8, int16, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int8, int16, int32, int64.
name (str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Expand Down Expand Up @@ -729,6 +737,8 @@ def greater_than(x, y, name=None):
"float16",
"float32",
"float64",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand All @@ -743,6 +753,8 @@ def greater_than(x, y, name=None):
"float16",
"float32",
"float64",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand Down Expand Up @@ -787,8 +799,8 @@ def less_equal(x, y, name=None):
The output has no gradient.
Args:
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64.
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int8, int16, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int8, int16, int32, int64.
name (str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Expand Down Expand Up @@ -818,6 +830,8 @@ def less_equal(x, y, name=None):
"float16",
"float32",
"float64",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand All @@ -832,6 +846,8 @@ def less_equal(x, y, name=None):
"float16",
"float32",
"float64",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand Down Expand Up @@ -876,8 +892,8 @@ def less_than(x, y, name=None):
The output has no gradient.
Args:
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64.
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int8, int16, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int8, int16, int32, int64.
name (str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Expand Down Expand Up @@ -908,6 +924,7 @@ def less_than(x, y, name=None):
"float32",
"float64",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand All @@ -923,6 +940,7 @@ def less_than(x, y, name=None):
"float32",
"float64",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand Down Expand Up @@ -967,8 +985,8 @@ def not_equal(x, y, name=None):
The output has no gradient.
Args:
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, int8, int16, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, int8, int16, int32, int64.
name (str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Expand Down Expand Up @@ -998,6 +1016,8 @@ def not_equal(x, y, name=None):
"float16",
"float32",
"float64",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand All @@ -1012,6 +1032,8 @@ def not_equal(x, y, name=None):
"float16",
"float32",
"float64",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand Down
25 changes: 9 additions & 16 deletions test/legacy_test/test_compare_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,15 @@ def test_errors(self):
globals()[cls_name] = Cls


for _type_name in {'float32', 'float64', 'int32', 'int64', 'float16'}:
for _type_name in {
'float32',
'float64',
'int8',
'int16',
'int32',
'int64',
'float16',
}:
if _type_name == 'float64' and core.is_compiled_with_rocm():
_type_name = 'float32'
if _type_name == 'float16' and (not core.is_compiled_with_cuda()):
Expand Down Expand Up @@ -615,21 +623,6 @@ def test_place_2(self):
self.assertEqual((result.numpy() == np.array([False])).all(), True)


class TestLessThanInt8(unittest.TestCase):
def test_less_than_int8(self):
# Create a tensor of type int8
x = paddle.to_tensor([1, 2, 3], dtype='int8')
y = paddle.to_tensor([1, 3, 2], dtype='int8')

result = paddle.less_than(x, y)

# desired output
expected = np.array([False, True, False])

# Verify output
self.assertTrue((result.numpy() == expected).all())


if __name__ == '__main__':
paddle.enable_static()
unittest.main()

0 comments on commit f22b939

Please sign in to comment.