diff --git a/python/paddle/base/reader.py b/python/paddle/base/reader.py index abca7f527db9a..4926cf4b63ab6 100644 --- a/python/paddle/base/reader.py +++ b/python/paddle/base/reader.py @@ -23,6 +23,7 @@ import paddle from paddle.base.framework import _set_expected_place +from paddle.pir.core import datatype_to_vartype from . import core from .data_feeder import BatchedTensorProvider, DataFeeder @@ -35,6 +36,7 @@ default_main_program, default_startup_program, in_dygraph_mode, + in_pir_mode, program_guard, ) from .layers.io import ( @@ -840,10 +842,16 @@ def _init_iterable(self): self._wait_thread_ends() self._var_names = [v.name for v in self._feed_list] self._shapes = [v.shape for v in self._feed_list] - self._dtypes = [v.dtype for v in self._feed_list] - self._need_check_feed = [ - v.desc.need_check_feed() for v in self._feed_list - ] + if in_pir_mode(): + self._dtypes = [ + datatype_to_vartype[v.dtype] for v in self._feed_list + ] + self._need_check_feed = [False for v in self._feed_list] + else: + self._dtypes = [v.dtype for v in self._feed_list] + self._need_check_feed = [ + v.desc.need_check_feed() for v in self._feed_list + ] self._queue = core.init_lod_tensor_blocking_queue( core.Variable(), self._capacity, self._keep_order ) @@ -874,7 +882,10 @@ def _init_non_iterable(self): ranks.append(len(feed_data.shape)) shapes.append(feed_data.shape) lod_levels.append(feed_data.lod_level) - need_check_feed.append(int(feed_data.desc.need_check_feed())) + if in_pir_mode(): + need_check_feed.append(0) + else: + need_check_feed.append(int(feed_data.desc.need_check_feed())) queue_name = data_loader_unique_name_generator( 'lod_tensor_blocking_queue' diff --git a/test/deprecated/legacy_test/test_py_reader_combination.py b/test/legacy_test/test_py_reader_combination.py similarity index 99% rename from test/deprecated/legacy_test/test_py_reader_combination.py rename to test/legacy_test/test_py_reader_combination.py index df62b0b61ccf7..f685fca746118 100644 --- a/test/deprecated/legacy_test/test_py_reader_combination.py +++ b/test/legacy_test/test_py_reader_combination.py @@ -119,4 +119,5 @@ def setUp(self): if __name__ == '__main__': + paddle.enable_static() unittest.main()