Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic version of pixel_shuffle - 1d, 2d, ..., nd #6340

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 30 additions & 14 deletions torch/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1645,8 +1645,9 @@ def multi_margin_loss(input, target, p=1, margin=1, weight=None, size_average=Tr


def pixel_shuffle(input, upscale_factor):
r"""Rearranges elements in a tensor of shape :math:`[*, C*r^2, H, W]` to a
tensor of shape :math:`[C, H*r, W*r]`.
r"""Rearranges elements in a Tensor of shape :math:`(N, C, d_{1}, d_{2}, ..., d_{n})` to a
tensor of shape :math:`(N, C/(r^n), d_{1}*r, d_{2}*r, ..., d_{n}*r)`.
Where :math:`n` is the dimensionality of the data.

See :class:`~torch.nn.PixelShuffle` for details.

Expand All @@ -1655,25 +1656,40 @@ def pixel_shuffle(input, upscale_factor):
upscale_factor (int): factor to increase spatial resolution by

Examples::
# 1D example
>>> input = torch.Tensor(1, 4, 8)
>>> output = F.pixel_shuffle(input, 2)
>>> print(output.size())
torch.Size([1, 2, 16])

# 2D example
>>> input = torch.Tensor(1, 9, 8, 8)
>>> output = F.pixel_shuffle(input, 3)
>>> print(output.size())
torch.Size([1, 1, 24, 24])

>>> ps = nn.PixelShuffle(3)
>>> input = torch.Tensor(1, 9, 4, 4)
>>> output = ps(input)
# 3D example
>>> input = torch.Tensor(1, 8, 16, 16, 16)
>>> output = F.pixel_shuffle(input, 2)
>>> print(output.size())
torch.Size([1, 1, 12, 12])
torch.Size([1, 1, 32, 32, 32])
"""
batch_size, channels, in_height, in_width = input.size()
channels //= upscale_factor ** 2
input_size = list(input.size())
dimensionality = len(input_size) - 2

out_height = in_height * upscale_factor
out_width = in_width * upscale_factor
input_size[1] //= (upscale_factor ** dimensionality)
output_size = [dim * upscale_factor for dim in input_size[2:]]

input_view = input.contiguous().view(
batch_size, channels, upscale_factor, upscale_factor,
in_height, in_width)
input_size[0], input_size[1],
*(([upscale_factor] * dimensionality) + input_size[2:])
)

indicies = list(range(2, 2 + 2 * dimensionality))

This comment was marked as off-topic.

indicies = indicies[1::2] + indicies[0::2]

shuffle_out = input_view.permute(0, 1, 4, 2, 5, 3).contiguous()
return shuffle_out.view(batch_size, channels, out_height, out_width)
shuffle_out = input_view.permute(0, 1, *(indicies[::-1])).contiguous()
return shuffle_out.view(input_size[0], input_size[1], *output_size)


def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners=None):
Expand Down
34 changes: 28 additions & 6 deletions torch/nn/modules/pixelshuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@


class PixelShuffle(Module):
r"""Rearranges elements in a Tensor of shape :math:`(*, r^2C, H, W)` to a
tensor of shape :math:`(C, rH, rW)`.
r"""Rearranges elements in a Tensor of shape :math:`(N, C, d_{1}, d_{2}, ..., d_{n})` to a
tensor of shape :math:`(N, C/(r^n), d_{1}*r, d_{2}*r, ..., d_{n}*r)`.
Where :math:`n` is the dimensionality of the data.

This is useful for implementing efficient sub-pixel convolution
with a stride of :math:`1/r`.

Input Tensor must have at least 3 dimensions, e.g. :math:`(N, C, d_{1})` for 1D data,
but Tensors with any number of dimensions after :math:`(N, C, ...)` (where N is mini-batch size,
and C is channels) are supported.

Look at the paper:
`Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_
by Shi et. al (2016) for more details
Expand All @@ -17,16 +22,33 @@ class PixelShuffle(Module):
upscale_factor (int): factor to increase spatial resolution by

Shape:
- Input: :math:`(N, C * \text{upscale_factor}^2, H, W)`
- Output: :math:`(N, C, H * \text{upscale_factor}, W * \text{upscale_factor})`
- Input: :math:`(N, C, d_{1}, d_{2}, ..., d_{n})`
- Output: :math:`(N, C/(r^n), d_{1}*r, d_{2}*r, ..., d_{n}*r)`
Where :math:`n` is the dimensionality of the data, e.g. :math:`n-1` for 1D audio,
:math:`n=2` for 2D images, etc.

Examples::

# 1D example
>>> ps = nn.PixelShuffle(2)
>>> input = torch.Tensor(1, 4, 8)
>>> output = ps(input)
>>> print(output.size())
torch.Size([1, 2, 16])

# 2D example
>>> ps = nn.PixelShuffle(3)
>>> input = torch.Tensor(1, 9, 4, 4)
>>> input = torch.Tensor(1, 9, 8, 8)
>>> output = ps(input)
>>> print(output.size())
torch.Size([1, 1, 24, 24])

# 3D example
>>> ps = nn.PixelShuffle(2)
>>> input = torch.Tensor(1, 8, 16, 16, 16)
>>> output = ps(input)
>>> print(output.size())
torch.Size([1, 1, 12, 12])
torch.Size([1, 1, 32, 32, 32])

.. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network:
https://arxiv.org/abs/1609.05158
Expand Down