add pixel_shuffle #260

Merged
merged 7 commits into master from cl/pixel on Dec 31, 2020
Conversation

CarloLucibello
Member

No description provided.

Member

@darsnack darsnack left a comment


Implementation looks good to me, but I think Julia's column-major behavior means we get the transposed result.

@CarloLucibello
Member Author

CarloLucibello commented Dec 31, 2020

There is an inconsistency with the PyTorch implementation:

In [21]: x=torch.arange(1,16*2+1).reshape((1,8,2,2))

In [22]: nn.PixelShuffle(2)(x)
Out[22]: 
tensor([[[[ 1,  5,  2,  6],
          [ 9, 13, 10, 14],
          [ 3,  7,  4,  8],
          [11, 15, 12, 16]],

         [[17, 21, 18, 22],
          [25, 29, 26, 30],
          [19, 23, 20, 24],
          [27, 31, 28, 32]]]])

while here we have

julia> x = reshape(1:32, (2,2,8,1))

julia> pixel_shuffle(x, 2)
4×4×2×1 Array{Int64,4}:
[:, :, 1, 1] =     # expected transpose of pytorch's first channel
 1   9  2  10
 5  13  6  14
 3  11  4  12
 7  15  8  16

[:, :, 2, 1] =    # expected transpose of pytorch's second channel
 17  25  18  26
 21  29  22  30
 19  27  20  28
 23  31  24  32

I think the difference is irrelevant for the application, but we should try to be consistent.
I'm not sure we can express this operation with permutedims, though.
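
For reference, a 2D-only sketch of the reshape/permutedims route (not the PR's code; the arbitrary-d version further down the thread generalizes it):

function pixel_shuffle_2d(x::AbstractArray{T,4}, r::Integer) where T
    w, h, c, n = size(x)
    cout = c ÷ r^2
    y = reshape(x, w, h, r, r, cout, n)      # split channels into r×r sub-pixel blocks
    y = permutedims(y, (3, 1, 4, 2, 5, 6))   # interleave sub-pixel dims with spatial dims
    return reshape(y, w*r, h*r, cout, n)
end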

@DhairyaLGandhi
Member

Is the phase shift being handled appropriately?

@CarloLucibello
Member Author

I think not. I found a good reference in the README here:
https://github.com/atriumlts/subpixel
I'll implement it later.

@CarloLucibello
Member Author

This is a TensorFlow implementation from that README:

def _phase_shift(I, r):
    # Helper function with main phase shift operation
    bsize, a, b, c = I.get_shape().as_list()
    X = tf.reshape(I, (bsize, a, b, r, r))
    X = tf.transpose(X, (0, 1, 2, 4, 3))  # bsize, a, b, 1, 1
    X = tf.split(1, a, X)  # a, [bsize, b, r, r]
    X = tf.concat(2, [tf.squeeze(x) for x in X])  # bsize, b, a*r, r
    X = tf.split(1, b, X)  # b, [bsize, a*r, r]
    X = tf.concat(2, [tf.squeeze(x) for x in X])  # bsize, a*r, b*r
    return tf.reshape(X, (bsize, a*r, b*r, 1))

def PS(X, r, color=False):
  # Main OP that you can arbitrarily use in your tensorflow code
  if color:
    Xc = tf.split(3, 3, X)
    X = tf.concat(3, [_phase_shift(x, r) for x in Xc])
  else:
    X = _phase_shift(X, r)
  return X

@DhairyaLGandhi
Member

DhairyaLGandhi commented Dec 31, 2020

Looks like #112 already implements this

Too much indexing and mutation, though; it is not AD-friendly.

@CarloLucibello
Member Author

PyTorch's PR doesn't look dissimilar from the code here:
https://github.com/pytorch/pytorch/pull/338/files
Arbitrary-dim implementation:
https://github.com/pytorch/pytorch/pull/6340/files

@CarloLucibello
Member Author

The fix was easy. I'll try to generalize to arbitrary dimensions.

@CarloLucibello
Member Author

CarloLucibello commented Dec 31, 2020

This works for arbitrary d, but it hits a weird Zygote bug:

function pixel_shuffle(x::AbstractArray, r::Integer)
    @assert ndims(x) > 2
    d = ndims(x) - 2                      # number of spatial dimensions
    sizein = size(x)[1:d]                 # spatial sizes
    cin, n = size(x, d+1), size(x, d+2)   # input channels and batch size
    cout = cin ÷ r^d                      # output channels
    # split the channel dimension into d extra factors of r
    x = reshape(x, sizein..., fill(r, d)..., cout, n)
    # interleave each factor of r with the corresponding spatial dimension
    perm = [d+1:2d 1:d]' |> vec  # = [d+1, 1, d+2, 2, ..., 2d, d]
    x = permutedims(x, (perm..., 2d+1, 2d+2))
    return reshape(x, ((r .* sizein)..., cout, n))
end
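
Running this version on the earlier example gives the expected transpose of PyTorch's first channel (a quick sanity check worked out by hand, not copied from the PR):

julia> x = reshape(1:32, (2,2,8,1));

julia> pixel_shuffle(x, 2)[:, :, 1, 1]
4×4 Array{Int64,2}:
 1   9  3  11
 5  13  7  15
 2  10  4  12
 6  14  8  16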

@CarloLucibello
Member Author

@DhairyaLGandhi do you have suggestions on how to modify the above code to work around the bug?

@darsnack
Member

Could you use ntuple instead of fill?

@CarloLucibello
Member Author

Could you use ntuple instead of fill?

yes, it works!
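
For reference, the change amounts to replacing the fill call with ntuple (a sketch of the one-line diff; the exact form in the PR may differ):

# before: differentiating through fill(r, d) hit the Zygote bug
x = reshape(x, sizein..., fill(r, d)..., cout, n)
# after: ntuple builds the same r, ..., r sizes as a tuple, which Zygote handles
x = reshape(x, sizein..., ntuple(_ -> r, d)..., cout, n)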

@darsnack
Member

I think matching PyTorch's output for consistency even if it doesn't affect the application would be good. AFAICT, PyTorch matches the paper.

@DhairyaLGandhi
Member

DhairyaLGandhi commented Dec 31, 2020

Definitely matching the output would be agreeable. Should fix before reviewing.

@CarloLucibello
Member Author

It matches PyTorch's implementation now
(the two hardcoded tests are taken from PyTorch's outputs).
This is ready to go, and once it is in I'll add the corresponding structural layer in Flux.
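
For context, that structural layer would presumably be a thin wrapper around pixel_shuffle; a hypothetical sketch (names assumed, not the Flux PR's code):

struct PixelShuffle
    r::Int
end

(ps::PixelShuffle)(x) = pixel_shuffle(x, ps.r)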

@CarloLucibello CarloLucibello merged commit c08258b into master Dec 31, 2020
@CarloLucibello CarloLucibello deleted the cl/pixel branch June 15, 2023 17:06