Skip to content

Commit

Permalink
pir support pixel unshuffle op (PaddlePaddle#57521)
Browse files Browse the repository at this point in the history
  • Loading branch information
phlrain authored and jiahy0825 committed Oct 16, 2023
1 parent 039665a commit 9daac57
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 137 deletions.
105 changes: 0 additions & 105 deletions paddle/fluid/operators/pixel_unshuffle_op.cc

This file was deleted.

9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1646,6 +1646,15 @@
kernel :
func : pixel_shuffle_grad

- backward_op : pixel_unshuffle_grad
forward : pixel_unshuffle (Tensor x, int downscale_factor=1, str data_format="NCHW") -> Tensor(out)
args : (Tensor out_grad, int downscale_factor, str data_format)
output : Tensor(x_grad)
infer_meta :
func : PixelUnshuffleGradInferMeta
kernel :
func : pixel_unshuffle_grad

- backward_op : poisson_grad
forward : poisson (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2103,6 +2103,13 @@
outputs :
out : Out

- op : pixel_unshuffle
backward : pixel_unshuffle_grad
inputs :
x : X
outputs :
out : Out

- op : poisson
inputs :
x : X
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1934,6 +1934,15 @@
func : pixel_shuffle
backward : pixel_shuffle_grad

- op : pixel_unshuffle
args : (Tensor x, int downscale_factor=1, str data_format="NCHW")
output : Tensor
infer_meta :
func : PixelUnshuffleInferMeta
kernel :
func : pixel_unshuffle
backward : pixel_unshuffle_grad

- op : poisson
args : (Tensor x)
output : Tensor
Expand Down
30 changes: 0 additions & 30 deletions paddle/phi/ops/compat/pixel_unshuffle_sig.cc

This file was deleted.

4 changes: 2 additions & 2 deletions test/legacy_test/test_pixel_unshuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ def pixel_unshuffle_np(x, down_factor, data_format="NCHW"):


def pixel_unshuffle_wrapper(x, downscale_factor, data_format):
return paddle._legacy_C_ops.pixel_unshuffle(
x, "downscale_factor", downscale_factor, "data_format", data_format
return paddle.nn.functional.pixel_unshuffle(
x, downscale_factor, data_format
)


Expand Down

0 comments on commit 9daac57

Please sign in to comment.