diff --git a/docs/source/nn.rst b/docs/source/nn.rst index 25d4b468ff5..4e104e30438 100644 --- a/docs/source/nn.rst +++ b/docs/source/nn.rst @@ -98,6 +98,8 @@ Operators for neural networks FusedBatchNorm3d, FusedMLP, +.. autofunction:: oneflow.nn.modules.pixelshuffle.PixelShufflev2 + .. autofunction:: oneflow.nn.parallel.DistributedDataParallel .. currentmodule:: oneflow.nn.utils diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index 2d92b8638b1..659b28d68e0 100755 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -2037,3 +2037,7 @@ - name: "einsum" signature: "Tensor (String equation, TensorTuple operands) => EinSum" bind_python: True + +- name: "pixel_shuffle" + signature: "Tensor (Tensor input, Int64 h_upscale_factor, Int64 w_upscale_factor) => PixelShuffle" + bind_python: True diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index 2ab1782faed..f76744e265c 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -411,6 +411,39 @@ class LayerNormAffineFunctor { std::shared_ptr op_; }; +class PixelShuffleFunctor { + public: + PixelShuffleFunctor() {} + Maybe operator()(const std::shared_ptr& x, const int64_t& h_upscale_factor, + const int64_t& w_upscale_factor) const { + CHECK_OR_RETURN(x->ndim() == 4) << "Only Accept 4D Tensor"; + const int64_t batch = x->shape()->At(0); + const int64_t channel = x->shape()->At(1); + const int64_t height = x->shape()->At(2); + const int64_t width = x->shape()->At(3); + std::shared_ptr out; + CHECK_OR_RETURN(channel % (h_upscale_factor * w_upscale_factor) == 0) + << "The channels of input tensor must be divisible by (upscale_factor * upscale_factor) or " + "(h_upscale_factor * w_upscale_factor)"; + const int64_t new_c = static_cast(channel / (h_upscale_factor * w_upscale_factor)); + std::vector permute_vec = {0, 1, 4, 2, 5, 3}; + std::vector reshape_vec_1 = {batch, new_c, h_upscale_factor * w_upscale_factor, height, + width}; + Shape reshape_1(DimVector(reshape_vec_1.begin(), reshape_vec_1.end())); + std::vector reshape_vec_2 = {batch, new_c, h_upscale_factor, w_upscale_factor, + height, width}; + Shape reshape_2(DimVector(reshape_vec_2.begin(), reshape_vec_2.end())); + std::vector reshape_vec_3 = {batch, new_c, height * h_upscale_factor, + width * w_upscale_factor}; + Shape reshape_3(DimVector(reshape_vec_3.begin(), reshape_vec_3.end())); + out = JUST(Reshape(x, reshape_1)); + out = JUST(Reshape(out, reshape_2)); + out = JUST(Permute(out, permute_vec)); + out = JUST(Reshape(out, reshape_3)); + return out; + } +}; + class PoolNDFunctor { public: PoolNDFunctor() = default; @@ -2455,6 +2488,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("Pad"); m.add_functor("Dropout"); m.add_functor("DropoutGrad"); + m.add_functor("PixelShuffle"); m.add_functor("Avgpool1D"); m.add_functor("Avgpool2D"); m.add_functor("Avgpool3D"); diff --git a/python/oneflow/nn/modules/pixelshuffle.py b/python/oneflow/nn/modules/pixelshuffle.py index d5c2662f9ad..2ad9c98b571 100644 --- a/python/oneflow/nn/modules/pixelshuffle.py +++ b/python/oneflow/nn/modules/pixelshuffle.py @@ -15,6 +15,7 @@ """ from typing import Optional +import oneflow as flow from oneflow.framework.tensor import Tensor from oneflow.nn.module import Module @@ -108,35 +109,9 @@ def __init__( self.w_upscale_factor = w_upscale_factor def forward(self, input: Tensor) -> Tensor: - assert len(input.shape) == 4, "Only Accept 4D Tensor" - (_batch, _channel, _height, _width) = input.shape - assert ( - _channel % (self.h_upscale_factor * self.w_upscale_factor) == 0 - ), "The channels of input tensor must be divisible by (upscale_factor * upscale_factor) or (h_upscale_factor * w_upscale_factor)" - _new_c = int(_channel / (self.h_upscale_factor * self.w_upscale_factor)) - out = input.reshape( - _batch, - _new_c, - self.h_upscale_factor * self.w_upscale_factor, - _height, - _width, - ) - out = out.reshape( - _batch, - _new_c, - self.h_upscale_factor, - self.w_upscale_factor, - _height, - _width, - ) - out = out.permute(0, 1, 4, 2, 5, 3) - out = out.reshape( - _batch, - _new_c, - _height * self.h_upscale_factor, - _width * self.w_upscale_factor, + return flow._C.pixel_shuffle( + input, self.h_upscale_factor, self.w_upscale_factor ) - return out def extra_repr(self) -> str: return f"w_upscale_factor={self.w_upscale_factor}, h_upscale_factor={self.h_upscale_factor}"