Skip to content

Commit

Permalink
Update test_pixel_unshuffle.py
Browse files Browse the repository at this point in the history
  • Loading branch information
BrilliantYuKaimin committed Mar 19, 2022
1 parent 8a259c0 commit b28157e
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,27 @@
import paddle.fluid.core as core
import paddle.fluid as fluid

paddle.enable_static()

def pixel_unshuffle_np(x, down_factor, data_format="NCHW"):
if data_format == "NCHW":
n, c, h, w = x.shape
new_shape = (n, c, h / down_factor, down_factor,
w / down_factor, down_factor)
new_shape = (n, c, h // down_factor, down_factor,
w // down_factor, down_factor)
npresult = np.reshape(x, new_shape)
npresult = npresult.transpose(0, 1, 3, 5, 2, 4)
oshape = [n, c * down_factor * down_factor, h / down_factor,
w / down_factor]
oshape = [n, c * down_factor * down_factor, h // down_factor,
w // down_factor]
npresult = np.reshape(npresult, oshape)
return npresult
else:
n, h, w, c = x.shape
new_shape = (n, h / down_factor, down_factor,
w / down_factor, down_factor, c)
new_shape = (n, h // down_factor, down_factor,
w // down_factor, down_factor, c)
npresult = np.reshape(x, new_shape)
npresult = npresult.transpose(0, 1, 3, 5, 2, 4)
oshape = [n, h / down_factor,
w / down_factor, c * down_factor * down_factor]
oshape = [n, h // down_factor,
w // down_factor, c * down_factor * down_factor]
npresult = np.reshape(npresult, oshape)
return npresult

Expand Down

0 comments on commit b28157e

Please sign in to comment.