add pixel_shuffle #260
Conversation
Implementation looks good to me, but I think Julia's column-major behavior means we get the transposed result.
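For context, a one-line illustration of the column-major behavior being referred to (standard Julia semantics, shown here only for reference):

```julia
reshape(1:4, 2, 2)  # column-major fill: [1 3; 2 4]; NumPy/PyTorch (row-major) would give [1 2; 3 4]
```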
There is an inconsistency with pytorch's implementation:

```python
In [21]: x = torch.arange(1, 16*2+1).reshape((1, 8, 2, 2))

In [22]: nn.PixelShuffle(2)(x)
Out[22]:
tensor([[[[ 1,  5,  2,  6],
          [ 9, 13, 10, 14],
          [ 3,  7,  4,  8],
          [11, 15, 12, 16]],

         [[17, 21, 18, 22],
          [25, 29, 26, 30],
          [19, 23, 20, 24],
          [27, 31, 28, 32]]]])
```

while here we have

```julia
julia> x = reshape(1:32, (2,2,8,1))

julia> pixel_shuffle(x, 2)
4×4×2×1 Array{Int64,4}:
[:, :, 1, 1] =   # expected transpose of pytorch's first channel
 1   9  2  10
 5  13  6  14
 3  11  4  12
 7  15  8  16

[:, :, 2, 1] =   # expected transpose of pytorch's second channel
 17  25  18  26
 21  29  22  30
 19  27  20  28
 23  31  24  32
```

I think the difference is irrelevant for the application, but we should try to be consistent.
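To make the consistency target concrete: under the usual WHCN (Julia) vs NCHW (PyTorch) layout correspondence, a matching implementation should make each Julia spatial slice the transpose of the corresponding PyTorch channel, as the annotations above say. A minimal sketch of that check (`torch_ch1` is just the channel copied from the session above; with the implementation as it stands at this point the comparison fails, which is the inconsistency being pointed out):

```julia
x = reshape(collect(1:32), (2, 2, 8, 1))
y = pixel_shuffle(x, 2)

# PyTorch's first output channel, copied from the session above
torch_ch1 = [ 1  5  2  6;
              9 13 10 14;
              3  7  4  8;
             11 15 12 16]

y[:, :, 1, 1] == permutedims(torch_ch1)  # should hold once the layouts agree
```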
Is the phase shift being handled appropriately?
I think not. I found a good reference in the README here.
This is a TensorFlow implementation:

```python
def _phase_shift(I, r):
    # Helper function with main phase shift operation
    bsize, a, b, c = I.get_shape().as_list()
    X = tf.reshape(I, (bsize, a, b, r, r))
    X = tf.transpose(X, (0, 1, 2, 4, 3))  # bsize, a, b, 1, 1
    X = tf.split(1, a, X)  # a, [bsize, b, r, r]
    X = tf.concat(2, [tf.squeeze(x) for x in X])  # bsize, b, a*r, r
    X = tf.split(1, b, X)  # b, [bsize, a*r, r]
    X = tf.concat(2, [tf.squeeze(x) for x in X])  # bsize, a*r, b*r
    return tf.reshape(X, (bsize, a*r, b*r, 1))

def PS(X, r, color=False):
    # Main OP that you can arbitrarily use in you tensorflow code
    if color:
        Xc = tf.split(3, 3, X)
        X = tf.concat(3, [_phase_shift(x, r) for x in Xc])
    else:
        X = _phase_shift(X, r)
    return X
```
Too much indexing and mutation; it is not AD friendly.
Pytorch's PR doesn't look dissimilar from the code here.
The fix was easy. I'll try to generalize to arbitrary dimensions.
This works in arbitrary dimensions:

```julia
function pixel_shuffle(x::AbstractArray, r::Integer)
    @assert ndims(x) > 2
    d = ndims(x) - 2
    sizein = size(x)[1:d]
    cin, n = size(x, d+1), size(x, d+2)
    cout = cin ÷ r^d
    x = reshape(x, sizein..., fill(r, d)..., cout, n)
    perm = [d+1:2d 1:d]' |> vec  # = [d+1, 1, d+2, 2, ..., 2d, d]
    x = permutedims(x, (perm..., 2d+1, 2d+2))
    return reshape(x, ((r .* sizein)..., cout, n))
end
```
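For illustration, a quick usage sketch of the generalized version above (output sizes only, just to show that the 1D and 3D cases go through; array contents and any later API changes aside):

```julia
x1 = rand(Float32, 8, 6, 1)         # (W, C, N): one spatial dim, r = 2 divides C
size(pixel_shuffle(x1, 2))          # -> (16, 3, 1)

x3 = rand(Float32, 4, 4, 4, 8, 1)   # (W, H, D, C, N): three spatial dims, r^3 = 8 divides C
size(pixel_shuffle(x3, 2))          # -> (8, 8, 8, 1, 1)
```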
@DhairyaLGandhi do you have suggestions on how to modify the above code to work around the bug?
Could you use …
Yes, it works!
I think matching PyTorch's output would be good for consistency, even if it doesn't affect the application. AFAICT, PyTorch matches the paper.
Definitely, matching the output would be agreeable. This should be fixed before reviewing.
It matches pytorch's implementation now |
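A small regression-test sketch for the example from this thread (the expected matrices are the transposes of PyTorch's output channels, per the WHCN vs NCHW correspondence discussed above):

```julia
using Test

x = reshape(collect(1:32), (2, 2, 8, 1))
y = pixel_shuffle(x, 2)

@test y[:, :, 1, 1] == [1 9 3 11; 5 13 7 15; 2 10 4 12; 6 14 8 16]     # transpose of pytorch channel 1
@test y[:, :, 2, 1] == [17 25 19 27; 21 29 23 31; 18 26 20 28; 22 30 24 32]  # transpose of pytorch channel 2
```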