Skip to content

Commit

Permalink
Migrate pixel_shuffle python layer to functor (#7745)
Browse files Browse the repository at this point in the history
* Migrate pixel_shuffle python layer to functor

* Format

* Format

* Refine, docstr
  • Loading branch information
lixiang007666 authored Mar 16, 2022
1 parent 7b82266 commit c4112a3
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 28 deletions.
2 changes: 2 additions & 0 deletions docs/source/nn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ Operators for neural networks
FusedBatchNorm3d,
FusedMLP,

.. autofunction:: oneflow.nn.modules.pixelshuffle.PixelShufflev2

.. autofunction:: oneflow.nn.parallel.DistributedDataParallel

.. currentmodule:: oneflow.nn.utils
Expand Down
4 changes: 4 additions & 0 deletions oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2037,3 +2037,7 @@
- name: "einsum"
signature: "Tensor (String equation, TensorTuple operands) => EinSum"
bind_python: True

- name: "pixel_shuffle"
signature: "Tensor (Tensor input, Int64 h_upscale_factor, Int64 w_upscale_factor) => PixelShuffle"
bind_python: True
34 changes: 34 additions & 0 deletions oneflow/core/functional/impl/nn_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,39 @@ class LayerNormAffineFunctor {
std::shared_ptr<OpExpr> op_;
};

class PixelShuffleFunctor {
public:
PixelShuffleFunctor() {}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const int64_t& h_upscale_factor,
const int64_t& w_upscale_factor) const {
CHECK_OR_RETURN(x->ndim() == 4) << "Only Accept 4D Tensor";
const int64_t batch = x->shape()->At(0);
const int64_t channel = x->shape()->At(1);
const int64_t height = x->shape()->At(2);
const int64_t width = x->shape()->At(3);
std::shared_ptr<one::Tensor> out;
CHECK_OR_RETURN(channel % (h_upscale_factor * w_upscale_factor) == 0)
<< "The channels of input tensor must be divisible by (upscale_factor * upscale_factor) or "
"(h_upscale_factor * w_upscale_factor)";
const int64_t new_c = static_cast<int>(channel / (h_upscale_factor * w_upscale_factor));
std::vector<int32_t> permute_vec = {0, 1, 4, 2, 5, 3};
std::vector<int64_t> reshape_vec_1 = {batch, new_c, h_upscale_factor * w_upscale_factor, height,
width};
Shape reshape_1(DimVector(reshape_vec_1.begin(), reshape_vec_1.end()));
std::vector<int64_t> reshape_vec_2 = {batch, new_c, h_upscale_factor, w_upscale_factor,
height, width};
Shape reshape_2(DimVector(reshape_vec_2.begin(), reshape_vec_2.end()));
std::vector<int64_t> reshape_vec_3 = {batch, new_c, height * h_upscale_factor,
width * w_upscale_factor};
Shape reshape_3(DimVector(reshape_vec_3.begin(), reshape_vec_3.end()));
out = JUST(Reshape(x, reshape_1));
out = JUST(Reshape(out, reshape_2));
out = JUST(Permute(out, permute_vec));
out = JUST(Reshape(out, reshape_3));
return out;
}
};

class PoolNDFunctor {
public:
PoolNDFunctor() = default;
Expand Down Expand Up @@ -2455,6 +2488,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::PadFunctor>("Pad");
m.add_functor<impl::DropoutFunctor>("Dropout");
m.add_functor<impl::DropoutGradFunctor>("DropoutGrad");
m.add_functor<impl::PixelShuffleFunctor>("PixelShuffle");
m.add_functor<impl::Avgpool1DFunctor>("Avgpool1D");
m.add_functor<impl::Avgpool2DFunctor>("Avgpool2D");
m.add_functor<impl::Avgpool3DFunctor>("Avgpool3D");
Expand Down
31 changes: 3 additions & 28 deletions python/oneflow/nn/modules/pixelshuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""
from typing import Optional

import oneflow as flow
from oneflow.framework.tensor import Tensor
from oneflow.nn.module import Module

Expand Down Expand Up @@ -108,35 +109,9 @@ def __init__(
self.w_upscale_factor = w_upscale_factor

def forward(self, input: Tensor) -> Tensor:
assert len(input.shape) == 4, "Only Accept 4D Tensor"
(_batch, _channel, _height, _width) = input.shape
assert (
_channel % (self.h_upscale_factor * self.w_upscale_factor) == 0
), "The channels of input tensor must be divisible by (upscale_factor * upscale_factor) or (h_upscale_factor * w_upscale_factor)"
_new_c = int(_channel / (self.h_upscale_factor * self.w_upscale_factor))
out = input.reshape(
_batch,
_new_c,
self.h_upscale_factor * self.w_upscale_factor,
_height,
_width,
)
out = out.reshape(
_batch,
_new_c,
self.h_upscale_factor,
self.w_upscale_factor,
_height,
_width,
)
out = out.permute(0, 1, 4, 2, 5, 3)
out = out.reshape(
_batch,
_new_c,
_height * self.h_upscale_factor,
_width * self.w_upscale_factor,
return flow._C.pixel_shuffle(
input, self.h_upscale_factor, self.w_upscale_factor
)
return out

def extra_repr(self) -> str:
return f"w_upscale_factor={self.w_upscale_factor}, h_upscale_factor={self.h_upscale_factor}"
Expand Down

0 comments on commit c4112a3

Please sign in to comment.