From c54d0db9118ba167ad6b21c9c8f9244459868a85 Mon Sep 17 00:00:00 2001 From: RedContritio Date: Fri, 20 Jan 2023 18:40:24 +0000 Subject: [PATCH 1/3] add dimension check in flip --- python/paddle/tensor/manipulation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index bdd903ee8f196..617f5297083f2 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1329,6 +1329,8 @@ def flip(x, axis, name=None): if isinstance(axis, int): axis = [axis] + assert np.array(axis).ndim == 1 + if in_dygraph_mode(): return _C_ops.flip(x, axis) else: From 323acbc969d4a452f0aaffc3b37638f0678f0957 Mon Sep 17 00:00:00 2001 From: RedContritio Date: Sun, 22 Jan 2023 10:09:29 +0000 Subject: [PATCH 2/3] raise ValueError when ndim incorrect --- python/paddle/tensor/manipulation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 617f5297083f2..9301715d1f8a5 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1329,7 +1329,8 @@ def flip(x, axis, name=None): if isinstance(axis, int): axis = [axis] - assert np.array(axis).ndim == 1 + if np.array(axis).ndim != 1: + raise ValueError('The axis of flip must be a list, tuple or int.') if in_dygraph_mode(): return _C_ops.flip(x, axis) From 108965ab5ea9c2204ac2353daf5b0e2e372df3d8 Mon Sep 17 00:00:00 2001 From: RedContritio Date: Sun, 22 Jan 2023 12:57:04 +0000 Subject: [PATCH 3/3] add unittest axis rank for flip --- python/paddle/fluid/tests/unittests/test_flip.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_flip.py b/python/paddle/fluid/tests/unittests/test_flip.py index 1807199821eb7..0856d8dec2940 100644 --- a/python/paddle/fluid/tests/unittests/test_flip.py +++ b/python/paddle/fluid/tests/unittests/test_flip.py @@ -198,6 +198,17 @@ def test_grad(self): self.func(p) +class TestFlipError(unittest.TestCase): + def test_axis(self): + paddle.enable_static() + + def test_axis_rank(): + input = fluid.data(name='input', dtype='float32', shape=[2, 3]) + output = paddle.flip(input, axis=[[0]]) + + self.assertRaises(ValueError, test_axis_rank) + + if __name__ == "__main__": paddle.enable_static() unittest.main()