In the current implementation, it appears to be quite challenging to compute the transpose convolution that actually corresponds to the transpose of a convolution with given kernel weights, at least when padding='CIRCULAR'.
In most cases, one can already achieve this by manually changing the kernel weights (flipping the spatial dimensions and swapping the channel axes of the kernel). This would arguably solve the issue. However, if padding='CIRCULAR', this doesn't work, since incompatible alignment conventions for the kernel are used.
I suppose this is a somewhat unusual situation since it only applies if parameters are specified externally.
Motivation
Consider $n$-dimensional convolution. Omitting channel and batch dimensions, convolution is a linear operation and can generally be written as $A = WB$. (For simplicity, $B$ can be thought of as the flattened input, in which case $W$ is a matrix.)
The transpose operation should then satisfy $\langle A, WB \rangle = \langle W^T A, B \rangle$.
For jax.lax.conv_transpose this is true if the parameter transpose_kernel is set to True.
In code, that corresponds to:
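(A minimal sketch of this property; the one-dimensional shapes, the NHC/HIO layout, and the odd kernel size are illustrative assumptions.)

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
x = jax.random.normal(k1, (1, 8, 2))  # input B: (batch, width, in_features)
w = jax.random.normal(k2, (3, 2, 4))  # kernel:  (width, in_features, out_features)

dn = jax.lax.conv_dimension_numbers(x.shape, w.shape, ('NHC', 'HIO', 'NHC'))

# Forward convolution: computes WB.
y = jax.lax.conv_general_dilated(
    x, w, window_strides=(1,), padding='SAME', dimension_numbers=dn)

# Transposed convolution with transpose_kernel=True: computes W^T A.
a = jax.random.normal(k3, y.shape)
z = jax.lax.conv_transpose(
    a, w, strides=(1,), padding='SAME',
    dimension_numbers=dn, transpose_kernel=True)

# Adjoint identity: <A, WB> == <W^T A, B>.
print(jnp.allclose(jnp.vdot(a, y), jnp.vdot(z, x), atol=1e-5))  # True

# The manual workaround mentioned above: flip the spatial axis and swap the
# channel axes of the kernel, then use a plain transposed convolution.
w_flipped = jnp.swapaxes(jnp.flip(w, axis=0), 1, 2)
z2 = jax.lax.conv_transpose(
    a, w_flipped, strides=(1,), padding='SAME',
    dimension_numbers=dn, transpose_kernel=False)
print(jnp.allclose(z, z2))  # True
```

Note that this uses an odd kernel size; with even-sized kernels and padding='CIRCULAR', the alignment convention discussed below gets in the way.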
This should ideally be reproducible with the ConvTranspose layer. In this case, it is possible to manually modify the kernel weights in the same way (flip the spatial dimensions, swap the channel axes). However, with padding='CIRCULAR' this manual workaround fails. The reason is the convention chosen here:

```python
# Divide the padding equally between left and right. The choice to put
# "+1" on the left (and not on the right) represents a convention for
# aligning even-sized kernels.
total_pad = [
    ((size_diff + 1) // 2, size_diff // 2) for size_diff in size_diffs
]
```
If this were switched to (size_diff // 2, (size_diff + 1) // 2), the manual workaround would work.
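To make the difference concrete, here is a small sketch of the two splits (the value of size_diff is an illustrative assumption):

```python
size_diff = 3  # illustrative, e.g. an even-sized kernel of size 4

current = ((size_diff + 1) // 2, size_diff // 2)   # (2, 1): extra element on the left
proposed = (size_diff // 2, (size_diff + 1) // 2)  # (1, 2): extra element on the right
```

When size_diff is odd, the two splits differ by one position; this is exactly the misalignment that breaks the manual kernel-flipping workaround.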
Options
1. Add a keyword argument transpose_kernel, which is passed on to jax.lax.conv_transpose and which also conditionally switches the convention for padding='CIRCULAR' (i.e. keeping the old convention when it is false, to stay backward compatible); a sketch follows after this list.
2. Instead of this keyword argument, one could simply change the convention that ConvTranspose uses for padding='CIRCULAR'. Then transpose_kernel can always be accomplished manually. However, this would change the network output with previously trained parameters.
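A sketch of how option 1 could look in the circular-padding branch (the transpose_kernel flag and the surrounding variable names are assumptions carried over from the snippet above, not existing Flax code):

```python
if transpose_kernel:
    # Proposed: put the "+1" on the right, so the padding aligns with the
    # transpose of the forward convolution for even-sized kernels.
    total_pad = [
        (size_diff // 2, (size_diff + 1) // 2) for size_diff in size_diffs
    ]
else:
    # Current convention, kept as the default for backward compatibility.
    total_pad = [
        ((size_diff + 1) // 2, size_diff // 2) for size_diff in size_diffs
    ]
```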