Transpose convolution matching convolution with given kernel parameters #2577

Closed
mathisgerdes opened this issue Nov 2, 2022 · 0 comments

With the current implementation, it is quite difficult to build a transpose convolution that is the actual linear-algebra transpose of a convolution with given kernel weights, at least when padding='CIRCULAR'.

In most cases, this can already be achieved by manually transforming the kernel weights (flipping the spatial dimensions and swapping the channel axes), as shown in the code below, which would arguably solve the issue. However, if padding='CIRCULAR', this doesn't work, because Conv and ConvTranspose use incompatible alignment conventions for the kernel.

I suppose this is a somewhat unusual situation since it only applies if parameters are specified externally.

Motivation

Consider n-dimensional convolution. Omitting channel and batch dimensions, a convolution is a linear operation and can generally be written as $A = WB$. (For simplicity, $B$ can be thought of as the flattened input, in which case $W$ is a matrix.)
The transpose operation should then satisfy $\langle A, WB \rangle = \langle W^T A, B \rangle$ for all $A$ and $B$.
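As a plain matrix sanity check of this identity (illustration only, not specific to convolutions):

import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 6))
B = rng.normal(size=(6,))
A = rng.normal(size=(4,))
# <A, WB> == <W^T A, B> holds for any matrix W
np.isclose(A @ (W @ B), (W.T @ A) @ B)  # True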

For jax.lax.conv_transpose this is true if the parameter transpose_kernel is set to True.
In code, that corresponds to:

# imports used throughout the snippets below
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn

IC, OC = 2, 3  # input / output channels
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

b = jax.random.normal(k1, (1, 30, 30, IC))
w = jax.random.normal(k2, (5, 5, IC, OC))
a = jax.random.normal(k3, (1, 26, 26, OC))

dim_num = ('NHWC', 'HWIO', 'NHWC')
w_b = jax.lax.conv_general_dilated(b, w, (1, 1), 'VALID', dimension_numbers=dim_num)

wt_a = jax.lax.conv_transpose(a, w, (1, 1), 'VALID', dimension_numbers=dim_num,
                              transpose_kernel=True)
# verify the adjoint identity <a, w_b> == <wt_a, b>
np.isclose(np.sum(a * w_b), np.sum(wt_a * b))  # True

This should ideally be reproducible with the ConvTranspose layer.
In this case, it is possible to manually modify the kernel weights:

conv = nn.Conv(OC, (5, 5), (1, 1), 'VALID', use_bias=False)
conv_tr = nn.ConvTranspose(IC, (5, 5), (1, 1), 'VALID', use_bias=False)

params = conv.init(k2, b)
# manually transpose kernel
kernel = params['params']['kernel']
kernel = jnp.flip(kernel, (0, 1)).swapaxes(-2, -1)
params_tr = {'params': {'kernel': kernel}}

w_b = conv.apply(params, b)
wt_a = conv_tr.apply(params_tr, a)
np.isclose(np.sum(a * w_b), np.sum(wt_a * b))  # True

However, with padding='CIRCULAR' this manual workaround fails:

b = jax.random.normal(k1, (1, 30, 30, IC))
w = jax.random.normal(k2, (5, 5, IC, OC))
a = jax.random.normal(k3, (1, 15, 15, OC))

conv = nn.Conv(OC, (5, 5), (2, 2), 'CIRCULAR', use_bias=False)
conv_tr = nn.ConvTranspose(IC, (5, 5), (2, 2), 'CIRCULAR', use_bias=False)

params = conv.init(k2, b)
# manually transpose kernel
kernel = params['params']['kernel']
kernel = jnp.flip(kernel, (0, 1)).swapaxes(-2, -1)
params_tr = {'params': {'kernel': kernel}}

w_b = conv.apply(params, b)
wt_a = conv_tr.apply(params_tr, a)
np.isclose(np.sum(a * w_b), np.sum(wt_a * b))  # False

The reason is the padding convention chosen here, in Flax's ConvTranspose implementation:

# Divide the padding equally between left and right. The choice to put
# "+1" on the left (and not on the right) represents a convention for
# aligning even-sized kernels.
total_pad = [
    ((size_diff + 1) // 2, size_diff // 2) for size_diff in size_diffs
]

If the split were switched to (size_diff // 2, (size_diff + 1) // 2), the adjoint identity would hold.
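To see why the split matters, here is a minimal 1-D illustration (an assumption-level sketch of the padding split, not the actual Flax code path): for an odd size_diff, the two conventions wrap the array with a one-sample shift, which misaligns every downstream convolution output.

x = jnp.arange(6.0)
size_diff = 3  # stands in for an odd total padding
left_heavy = jnp.pad(x, ((size_diff + 1) // 2, size_diff // 2), mode='wrap')
right_heavy = jnp.pad(x, (size_diff // 2, (size_diff + 1) // 2), mode='wrap')
print(left_heavy)   # [4. 5. 0. 1. 2. 3. 4. 5. 0.]
print(right_heavy)  # [5. 0. 1. 2. 3. 4. 5. 0. 1.]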

Options

  • Add a keyword argument transpose_kernel to ConvTranspose, which is passed through to jax.lax.conv_transpose and which also conditionally switches the padding convention for padding='CIRCULAR' (keeping the old convention when False, so existing behavior stays backward compatible); see the sketch after this list.
  • Alternatively, change the convention ConvTranspose uses for padding='CIRCULAR' unconditionally. The effect of transpose_kernel could then always be achieved by transforming the kernel manually. However, this would change the network output for previously trained parameters.
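A hypothetical sketch of the first option (transpose_kernel is not an existing nn.ConvTranspose argument; this only illustrates the proposed API):

conv_tr = nn.ConvTranspose(IC, (5, 5), (2, 2), 'CIRCULAR', use_bias=False,
                           transpose_kernel=True)  # hypothetical keyword
# the convolution's params could then be reused directly, without manual flipping
wt_a = conv_tr.apply(params, a)
np.isclose(np.sum(a * w_b), np.sum(wt_a * b))  # would be True under the proposal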