Skip to content

Commit

Permalink
update place
Browse files Browse the repository at this point in the history
  • Loading branch information
ooooo-create committed Aug 1, 2024
1 parent eb9fd62 commit eda5d0f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 21 deletions.
3 changes: 1 addition & 2 deletions paddle/phi/infermeta/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3760,9 +3760,8 @@ void ShuffleBatchInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
seed_out->share_dims(seed);
seed_out->share_lod(seed);
seed_out->set_dtype(x.dtype());
seed_out->set_dtype(seed.dtype());
shuffle_idx->set_dims(phi::make_ddim({-1}));
shuffle_idx->set_dtype(x.dtype());
}

void SequenceMaskInferMeta(const MetaTensor& x,
Expand Down
46 changes: 27 additions & 19 deletions test/legacy_test/test_shuffle_batch_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,44 +89,52 @@ def get_shape(self):

class TestShuffleBatchAPI(unittest.TestCase):
def setUp(self):
self.places = [paddle.CPUPlace()]
if not os.name == 'nt' and paddle.is_compiled_with_cuda():
self.places.append(paddle.CUDAPlace(0))
paddle.enable_static()

def tearDown(self):
paddle.disable_static()

def test_seed_without_tensor(self):
def api_run(seed):
def api_run(seed, place=paddle.CPUPlace()):
main_prog, startup_prog = (
paddle.static.Program(),
paddle.static.Program(),
)
with paddle.static.program_guard(main_prog, startup_prog):
x = paddle.static.data(name='x', shape=[-1, 4], dtype='float32')
out = paddle.incubate.layers.shuffle_batch(x, seed=seed)
exe = paddle.static.Executor()
exe = paddle.static.Executor(place=place)
feed = {'x': np.random.random((10, 4)).astype('float32')}
exe.run(startup_prog)
_ = exe.run(main_prog, feed=feed, fetch_list=[out])

api_run(seed=None)
api_run(seed=1)
for place in self.places:
api_run(None, place=place)
api_run(1, place=place)

def test_seed_with_tensor(self):
main_prog, startup_prog = (
paddle.static.Program(),
paddle.static.Program(),
)
with paddle.static.program_guard(main_prog, startup_prog):
x = paddle.static.data(name='x', shape=[-1, 4], dtype='float32')
seed = paddle.static.data(name='seed', shape=[1], dtype='int64')
out = paddle.incubate.layers.shuffle_batch(x, seed=seed)
exe = paddle.static.Executor()
feed = {
'x': np.random.random((10, 4)).astype('float32'),
'seed': np.array([1]).astype('int64'),
}
exe.run(startup_prog)
_ = exe.run(main_prog, feed=feed, fetch_list=[out])
def api_run(place=paddle.CPUPlace()):
main_prog, startup_prog = (
paddle.static.Program(),
paddle.static.Program(),
)
with paddle.static.program_guard(main_prog, startup_prog):
x = paddle.static.data(name='x', shape=[-1, 4], dtype='float32')
seed = paddle.static.data(name='seed', shape=[1], dtype='int64')
out = paddle.incubate.layers.shuffle_batch(x, seed=seed)
exe = paddle.static.Executor(place=place)
feed = {
'x': np.random.random((10, 4)).astype('float32'),
'seed': np.array([1]).astype('int64'),
}
exe.run(startup_prog)
_ = exe.run(main_prog, feed=feed, fetch_list=[out])

for place in self.places:
api_run(place=place)


if __name__ == '__main__':
Expand Down

0 comments on commit eda5d0f

Please sign in to comment.