diff --git a/python/paddle/base/data_feeder.py b/python/paddle/base/data_feeder.py index 6553338aea590..b629faf5cacc9 100644 --- a/python/paddle/base/data_feeder.py +++ b/python/paddle/base/data_feeder.py @@ -229,18 +229,22 @@ def check_dtype( def check_shape( shape, op_name, - expected_shape_type=(list, tuple, Variable), - expected_element_type=(int, Variable), + expected_shape_type=(list, tuple, Variable, Value), + expected_element_type=(int, Variable, Value), expected_tensor_dtype=('int32', 'int64'), ): # See NOTE [ Why skip dynamic graph check ] if in_dygraph_mode(): return check_type(shape, 'shape', expected_shape_type, op_name) - if expected_element_type is not None and not isinstance(shape, Variable): + if expected_element_type is not None and not isinstance( + shape, (Variable, Value) + ): for item in shape: check_type(item, 'element of shape', expected_element_type, op_name) - if expected_tensor_dtype is not None and isinstance(item, Variable): + if expected_tensor_dtype is not None and isinstance( + item, (Variable, Value) + ): check_dtype( item.dtype, 'element of shape', @@ -250,7 +254,9 @@ def check_shape( ', '.join(expected_tensor_dtype) ), ) - if expected_tensor_dtype is not None and isinstance(shape, Variable): + if expected_tensor_dtype is not None and isinstance( + shape, (Variable, Value) + ): check_dtype(shape.dtype, 'shape', expected_tensor_dtype, op_name) diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py index 496ec9965d0cf..551fa2336e8d1 100644 --- a/python/paddle/tensor/random.py +++ b/python/paddle/tensor/random.py @@ -1112,7 +1112,7 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None): low, high, shape, dtype, _current_expected_place() ) elif in_pir_mode(): - check_type(shape, 'shape', (list, tuple, paddle.pir.Value), 'randint') + check_shape(shape, 'randint') check_dtype(dtype, 'dtype', ['int32', 'int64'], 'randint') if paddle.utils._contain_var(shape): shape = paddle.utils.get_int_tensor_list( diff --git a/test/legacy_test/test_randint_op.py b/test/legacy_test/test_randint_op.py index 0558d7129fbe7..746138c138016 100644 --- a/test/legacy_test/test_randint_op.py +++ b/test/legacy_test/test_randint_op.py @@ -18,9 +18,8 @@ from op_test import OpTest import paddle -from paddle import base from paddle.base import core -from paddle.static import Program, program_guard +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -54,8 +53,11 @@ def verify_output(self, outs): class TestRandintOpError(unittest.TestCase): + @test_with_pir_api def test_errors(self): - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): self.assertRaises(TypeError, paddle.randint, 5, shape=np.array([2])) self.assertRaises(TypeError, paddle.randint, 5, dtype='float32') self.assertRaises(ValueError, paddle.randint, 5, 5) @@ -67,14 +69,6 @@ def test_errors(self): TypeError, paddle.randint, 5, shape=[shape_tensor] ) - def test_pir_error(self): - with paddle.pir_utils.IrGuard(): - self.assertRaises(TypeError, paddle.randint, 5, shape=np.array([2])) - self.assertRaises(TypeError, paddle.randint, 5, dtype='float32') - self.assertRaises(ValueError, paddle.randint, 5, 5) - self.assertRaises(ValueError, paddle.randint, -5) - self.assertRaises(TypeError, paddle.randint, 5, shape=['2']) - class TestRandintOp_attr_tensorlist(OpTest): def setUp(self): @@ -125,7 +119,9 @@ def verify_output(self, outs): # Test python API class TestRandintAPI(unittest.TestCase): def test_api(self): - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): # results are from [0, 5). out1 = paddle.randint(5) # shape is a list and dtype is 'int32' @@ -229,17 +225,20 @@ def test_dygraph(self): self.assertEqual(x.shape, []) paddle.enable_static() + @test_with_pir_api def test_static(self): - with base.program_guard(base.Program(), base.Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.randint(-10, 10, []) # Test compile shape - self.assertEqual(x.shape, ()) + self.assertEqual(tuple(x.shape), ()) # Test runtime shape - exe = base.Executor() + exe = paddle.static.Executor() result = exe.run(fetch_list=[x]) - self.assertEqual(result[0].shape, ()) + self.assertEqual(tuple(result[0].shape), ()) paddle.enable_static()