Commit c332828

support int32 and int64 kernel for clip operator (#32373)

wuyefeilin authored Apr 22, 2021
1 parent a1a527f commit c332828

Showing 4 changed files with 55 additions and 15 deletions.
8 changes: 6 additions & 2 deletions paddle/fluid/operators/clip_op.cc

@@ -145,10 +145,14 @@ REGISTER_OPERATOR(clip_grad, ops::ClipOpGrad, ops::ClipGradInplaceInferer,
                   ops::ClipDoubleGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(
     clip, ops::ClipKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::ClipKernel<paddle::platform::CPUDeviceContext, double>);
+    ops::ClipKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::ClipKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::ClipKernel<paddle::platform::CPUDeviceContext, int64_t>);
 REGISTER_OP_CPU_KERNEL(
     clip_grad, ops::ClipGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::ClipGradKernel<paddle::platform::CPUDeviceContext, double>);
+    ops::ClipGradKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::ClipGradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::ClipGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
 
 REGISTER_OP_VERSION(clip)
     .AddCheckpoint(
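The CPU registrations above are what let `paddle.clip` dispatch on integer tensors. A minimal smoke test of the new kernels, assuming a Paddle build that includes this commit:

import numpy as np
import paddle

paddle.set_device('cpu')

# Integer inputs now dispatch to the int / int64_t CPU kernels registered above.
x32 = paddle.to_tensor(np.arange(-5, 6, dtype=np.int32))
x64 = paddle.to_tensor(np.arange(-5, 6, dtype=np.int64))
y32 = paddle.clip(x32, min=-2, max=3)
y64 = paddle.clip(x64, min=-2, max=3)

# Outputs keep their integer dtypes and match numpy's reference clip.
assert y32.dtype == paddle.int32 and y64.dtype == paddle.int64
np.testing.assert_array_equal(y32.numpy(), np.arange(-5, 6, dtype=np.int32).clip(-2, 3))
np.testing.assert_array_equal(y64.numpy(), np.arange(-5, 6, dtype=np.int64).clip(-2, 3))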
8 changes: 6 additions & 2 deletions paddle/fluid/operators/clip_op.cu

@@ -17,8 +17,12 @@ limitations under the License. */
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     clip, ops::ClipKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::ClipKernel<paddle::platform::CUDADeviceContext, double>);
+    ops::ClipKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::ClipKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::ClipKernel<paddle::platform::CUDADeviceContext, int64_t>);
 
 REGISTER_OP_CUDA_KERNEL(
     clip_grad, ops::ClipGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::ClipGradKernel<paddle::platform::CUDADeviceContext, double>);
+    ops::ClipGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::ClipGradKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::ClipGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
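The CUDA side mirrors the CPU registrations one for one. A sketch of exercising the GPU kernels, assuming a CUDA-enabled build of this commit (it falls back to CPU otherwise):

import numpy as np
import paddle

# Pick the GPU when this wheel was built with CUDA; otherwise stay on CPU.
paddle.set_device('gpu' if paddle.is_compiled_with_cuda() else 'cpu')

x = paddle.to_tensor((np.random.rand(4, 10) * 10).astype(np.int64))
y = paddle.clip(x, min=2, max=8)
print(y.numpy().min(), y.numpy().max())  # values are bounded to [2, 8]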
28 changes: 26 additions & 2 deletions python/paddle/fluid/tests/unittests/test_clip_op.py

@@ -50,10 +50,14 @@ def setUp(self):
         self.outputs = {'Out': np.clip(self.inputs['X'], min_v, max_v)}
 
     def test_check_output(self):
+        paddle.enable_static()
         self.check_output()
+        paddle.disable_static()
 
     def test_check_grad_normal(self):
+        paddle.enable_static()
         self.check_grad(['X'], 'Out')
+        paddle.disable_static()
 
     def initTestCase(self):
         self.shape = (4, 10, 10)
@@ -102,6 +106,7 @@ def initTestCase(self):
 
 class TestClipOpError(unittest.TestCase):
     def test_errors(self):
+        paddle.enable_static()
         with program_guard(Program(), Program()):
             input_data = np.random.random((2, 4)).astype("float32")
 
@@ -115,6 +120,7 @@ def test_dtype():
                 fluid.layers.clip(x=x2, min=-1.0, max=1.0)
 
             self.assertRaises(TypeError, test_dtype)
+        paddle.disable_static()
 
 
 class TestClipAPI(unittest.TestCase):
@@ -140,15 +146,19 @@ def test_clip(self):
         out_8 = paddle.clip(images)
         out_9 = paddle.clip(paddle.cast(images, 'float64'), min=0.2, max=0.9)
 
-        res1, res2, res3, res4, res5, res6, res7, res8, res9 = exe.run(
+        out_10 = paddle.clip(paddle.cast(images * 10, 'int32'), min=2, max=8)
+        out_11 = paddle.clip(paddle.cast(images * 10, 'int64'), min=2, max=8)
+
+        res1, res2, res3, res4, res5, res6, res7, res8, res9, res10, res11 = exe.run(
             fluid.default_main_program(),
             feed={
                 "image": data,
                 "min": np.array([0.2]).astype('float32'),
                 "max": np.array([0.8]).astype('float32')
             },
             fetch_list=[
-                out_1, out_2, out_3, out_4, out_5, out_6, out_7, out_8, out_9
+                out_1, out_2, out_3, out_4, out_5, out_6, out_7, out_8, out_9,
+                out_10, out_11
             ])
 
         self.assertTrue(np.allclose(res1, data.clip(0.2, 0.8)))
@@ -161,8 +171,14 @@ def test_clip(self):
         self.assertTrue(np.allclose(res8, data))
         self.assertTrue(
             np.allclose(res9, data.astype(np.float64).clip(0.2, 0.9)))
+        self.assertTrue(
+            np.allclose(res10, (data * 10).astype(np.int32).clip(2, 8)))
+        self.assertTrue(
+            np.allclose(res11, (data * 10).astype(np.int64).clip(2, 8)))
+        paddle.disable_static()
 
     def test_clip_dygraph(self):
+        paddle.disable_static()
         place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
         ) else fluid.CPUPlace()
         paddle.disable_static(place)
@@ -176,16 +192,24 @@ def test_clip_dygraph(self):
         out_2 = paddle.clip(images, min=0.2, max=0.9)
         out_3 = paddle.clip(images, min=v_min, max=v_max)
 
+        out_4 = paddle.clip(paddle.cast(images * 10, 'int32'), min=2, max=8)
+        out_5 = paddle.clip(paddle.cast(images * 10, 'int64'), min=2, max=8)
+
         self.assertTrue(np.allclose(out_1.numpy(), data.clip(0.2, 0.8)))
         self.assertTrue(np.allclose(out_2.numpy(), data.clip(0.2, 0.9)))
         self.assertTrue(np.allclose(out_3.numpy(), data.clip(0.2, 0.8)))
+        self.assertTrue(
+            np.allclose(out_4.numpy(), (data * 10).astype(np.int32).clip(2, 8)))
+        self.assertTrue(
+            np.allclose(out_5.numpy(), (data * 10).astype(np.int64).clip(2, 8)))
 
     def test_errors(self):
+        paddle.enable_static()
         x1 = fluid.data(name='x1', shape=[1], dtype="int16")
         x2 = fluid.data(name='x2', shape=[1], dtype="int8")
         self.assertRaises(TypeError, paddle.clip, x=x1, min=0.2, max=0.8)
         self.assertRaises(TypeError, paddle.clip, x=x2, min=0.2, max=0.8)
+        paddle.disable_static()
 
 
 if __name__ == '__main__':
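The repeated paddle.enable_static() / paddle.disable_static() pairs exist because OpTest and fluid.data build static-graph programs, while Paddle 2.x defaults to dygraph mode. A small context-manager sketch of the same bracketing (the static_guard helper is ours, not part of Paddle):

import contextlib
import paddle

@contextlib.contextmanager
def static_guard():
    # Enter static-graph mode for the body, then restore dygraph even on error.
    paddle.enable_static()
    try:
        yield
    finally:
        paddle.disable_static()

# Usage mirroring the bracketing in TestClipOpError.test_errors:
with static_guard():
    x = paddle.static.data(name='x', shape=[1], dtype='float32')
    out = paddle.clip(x, min=0.2, max=0.8)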
26 changes: 17 additions & 9 deletions python/paddle/tensor/math.py

@@ -1475,10 +1475,10 @@ def clip(x, min=None, max=None, name=None):
         Out = MIN(MAX(x, min), max)
 
     Args:
-        x (Tensor): An N-D Tensor with data type float32 or float64.
-        min (float32|Tensor): The lower bound with type ``float32`` or a ``Tensor``
+        x (Tensor): An N-D Tensor with data type float32, float64, int32 or int64.
+        min (float|int|Tensor): The lower bound with type ``float``, ``int`` or a ``Tensor``
             with shape [1] and type ``int32``, ``float32``, ``float64``.
-        max (float32|Tensor): The upper bound with type ``float32`` or a ``Tensor``
+        max (float|int|Tensor): The upper bound with type ``float``, ``int`` or a ``Tensor``
             with shape [1] and type ``int32``, ``float32``, ``float64``.
         name (str, optional): The default value is None. Normally there is no
             need for user to set this property. For more information, please
@@ -1503,16 +1503,24 @@ def clip(x, min=None, max=None, name=None):
             # [[4.5, 6.4]
     """
 
-    fmin = float(np.finfo(np.float32).min)
-    fmax = float(np.finfo(np.float32).max)
+    x_dtype = str(x.dtype)
+    if x_dtype == 'paddle.int32':
+        min_ = np.iinfo(np.int32).min
+        max_ = np.iinfo(np.int32).max - 2**7
+    elif x_dtype == 'paddle.int64':
+        min_ = np.iinfo(np.int64).min
+        max_ = np.iinfo(np.int64).max - 2**39
+    else:
+        min_ = float(np.finfo(np.float32).min)
+        max_ = float(np.finfo(np.float32).max)
 
     if in_dygraph_mode():
         if isinstance(min, Variable):
             min = min.numpy().item(0)
         if isinstance(max, Variable):
             max = max.numpy().item(0)
-        min = fmin if min is None else min
-        max = fmax if max is None else max
+        min = min_ if min is None else min
+        max = max_ if max is None else max
        return core.ops.clip(x, "min", min, "max", max)
 
     if min is not None:
@@ -1526,10 +1534,10 @@ def clip(x, min=None, max=None, name=None):
             check_dtype(max.dtype, 'max', ['float32', 'float64', 'int32'],
                         'clip', '(When the type of max in clip is Variable.)')
 
-    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'clip')
+    check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'], 'clip')
 
     inputs = {'X': x}
-    attrs = {'min': fmin, 'max': fmax}
+    attrs = {'min': min_, 'max': max_}
 
     if isinstance(min, Variable):
         min.stop_gradient = True
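When min or max is omitted, clip now derives defaults from the input dtype instead of always using float32 limits. A standalone sketch of that selection; the 2**7 and 2**39 headroom subtracted from the integer maxima presumably keeps the defaults safe when they pass through the op's floating-point attributes (our reading; the commit does not explain it):

import numpy as np

def default_clip_bounds(np_dtype):
    # Mirrors the x_dtype branch added to paddle.tensor.math.clip above.
    if np_dtype == np.int32:
        return np.iinfo(np.int32).min, np.iinfo(np.int32).max - 2**7
    if np_dtype == np.int64:
        return np.iinfo(np.int64).min, np.iinfo(np.int64).max - 2**39
    return float(np.finfo(np.float32).min), float(np.finfo(np.float32).max)

print(default_clip_bounds(np.int32))  # (-2147483648, 2147483519)
print(default_clip_bounds(np.int64))  # (-9223372036854775808, 9223371487098961919)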
