Skip to content

Commit

Permalink
【Fix PIR Unittest No.74】Fix some test cast in PIR mode (PaddlePaddle#…
Browse files Browse the repository at this point in the history
…64350)

* fix_pir_api 4

* refine
  • Loading branch information
wanghuancoder authored and co63oc committed May 18, 2024
1 parent 270b0a2 commit 3cf1d63
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
21 changes: 16 additions & 5 deletions python/paddle/base/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import paddle
from paddle.base.framework import _set_expected_place
from paddle.pir.core import datatype_to_vartype

from . import core
from .data_feeder import BatchedTensorProvider, DataFeeder
Expand All @@ -35,6 +36,7 @@
default_main_program,
default_startup_program,
in_dygraph_mode,
in_pir_mode,
program_guard,
)
from .layers.io import (
Expand Down Expand Up @@ -840,10 +842,16 @@ def _init_iterable(self):
self._wait_thread_ends()
self._var_names = [v.name for v in self._feed_list]
self._shapes = [v.shape for v in self._feed_list]
self._dtypes = [v.dtype for v in self._feed_list]
self._need_check_feed = [
v.desc.need_check_feed() for v in self._feed_list
]
if in_pir_mode():
self._dtypes = [
datatype_to_vartype[v.dtype] for v in self._feed_list
]
self._need_check_feed = [False for v in self._feed_list]
else:
self._dtypes = [v.dtype for v in self._feed_list]
self._need_check_feed = [
v.desc.need_check_feed() for v in self._feed_list
]
self._queue = core.init_lod_tensor_blocking_queue(
core.Variable(), self._capacity, self._keep_order
)
Expand Down Expand Up @@ -874,7 +882,10 @@ def _init_non_iterable(self):
ranks.append(len(feed_data.shape))
shapes.append(feed_data.shape)
lod_levels.append(feed_data.lod_level)
need_check_feed.append(int(feed_data.desc.need_check_feed()))
if in_pir_mode():
need_check_feed.append(0)
else:
need_check_feed.append(int(feed_data.desc.need_check_feed()))

queue_name = data_loader_unique_name_generator(
'lod_tensor_blocking_queue'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,4 +119,5 @@ def setUp(self):


if __name__ == '__main__':
paddle.enable_static()
unittest.main()

0 comments on commit 3cf1d63

Please sign in to comment.