Conv Transpose doesn't exactly transpose #728

Closed
lockwo opened this issue May 17, 2024 · 4 comments
Labels
documentation Improvements or additions to documentation

Comments

@lockwo
Contributor

lockwo commented May 17, 2024

When checking if ConvTranspose is actually computing the transpose operation, it seems to be failing. I've tried different weight matrix shapes, but I'm uncertain as to why this is failing:

from jax import numpy as jnp
import jax
import equinox as eqx
key = jax.random.PRNGKey(0)
x = jax.random.uniform(key, shape=(2, 6, 6))
a = jax.random.uniform(jax.random.PRNGKey(24), shape=x.shape)

cnn = eqx.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=1, padding='SAME', key=key, use_bias=False)
cnn_t = eqx.nn.ConvTranspose2d(in_channels=2, out_channels=2, kernel_size=3, stride=1, padding='SAME', key=key, use_bias=False)
#cnn_t = eqx.tree_at(lambda x: x.weight, cnn_t, jnp.transpose(cnn.weight, (1, 0, 3, 2)))

fx = cnn(x)
ga = cnn_t(a)
dot_product_1 = jnp.sum(a * fx)
dot_product_2 = jnp.sum(x * ga)

assert jnp.allclose(dot_product_1, dot_product_2)

The following TF code succeeds as expected:

import tensorflow as tf
import numpy as np

x = tf.random.normal((1, 5, 5, 2))
a = tf.random.normal((1, 5, 5, 2))

filter_tensor = tf.random.normal((3, 3, 2, 2))

stride = 1
padding = 'SAME'
fx = tf.nn.conv2d(x, filter_tensor, strides=[1, stride, stride, 1], padding=padding)

u_height, u_width, u_depth = a.shape[1], a.shape[2], a.shape[3]
output_shape = [1, u_height, u_width, u_depth]

ga = tf.nn.conv2d_transpose(a, filter_tensor, output_shape=output_shape, strides=[1, stride, stride, 1], padding=padding)

dot_product_1 = tf.math.reduce_sum(a * fx)
dot_product_2 = tf.math.reduce_sum(x * ga)

assert np.allclose(dot_product_1, dot_product_2)

What's the best way to get the actual transpose of a convolution in equinox?

@lockwo
Contributor Author

lockwo commented May 17, 2024

Very similar thread in flax: google/flax#2577. However, their solution of manually transposing the kernel doesn't work here: cnn_t = eqx.tree_at(lambda x: x.weight, cnn_t, jnp.flip(jnp.array(cnn.weight), (0, 1)).swapaxes(-2, -1)) still fails. The transpose_kernel=True approach doesn't apply either, since conv_general_dilated is used rather than conv_transpose.

@lockwo
Contributor Author

lockwo commented May 17, 2024

I forgot that flax has its weight axes in reverse order. cnn_t = eqx.tree_at(lambda x: x.weight, cnn_t, jnp.flip(jnp.array(cnn.weight), (2, 3)).swapaxes(0, 1)) actually does work. Maybe the tooling for this should be documented somewhere? Or we could add it as a flag, like flax did in google/flax#2578?

@patrick-kidger
Owner

I think the flip + swapaxes combination would make sense to document somewhere.

I probably wouldn't add it as a flag. Transposed convolutions are already very complicated, I'd prefer not to add yet another complexity to them!

FWIW let's not forget that jax.linear_transpose exists. Back when I first wrote ConvTranspose I considered using jax.linear_transpose(Conv(...)) as its implementation. Sometimes I wonder if that's what I should have done!

patrick-kidger added the documentation label May 19, 2024
@lockwo
Contributor Author

lockwo commented May 19, 2024

Makes sense, these signatures could quickly become overwhelming. I will add it to the documentation.

I definitely did not forget about jax.linear_transpose, because I didn't even know it existed! So thanks for also bringing that to my attention.
