Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transpose_kernel argument to ConvTranspose. #2578

Merged
merged 1 commit into from
Jan 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions flax/linen/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,8 @@ class ConvTranspose(Module):
for details.
kernel_init: initializer for the convolutional kernel.
bias_init: initializer for the bias.
transpose_kernel: if True flips spatial axes and swaps the input/output
channel axes of the kernel.
"""
features: int
kernel_size: Union[int, Tuple[int, ...]]
Expand All @@ -596,6 +598,7 @@ class ConvTranspose(Module):
precision: PrecisionLike = None
kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
transpose_kernel: bool = False

@compact
def __call__(self, inputs: Array) -> Array:
Expand Down Expand Up @@ -638,7 +641,10 @@ def __call__(self, inputs: Array) -> Array:
strides = self.strides or (1,) * (inputs.ndim - 2)

in_features = jnp.shape(inputs)[-1]
kernel_shape = kernel_size + (in_features, self.features)
if self.transpose_kernel:
kernel_shape = kernel_size + (self.features, in_features)
else:
kernel_shape = kernel_size + (in_features, self.features)

if self.mask is not None and self.mask.shape != kernel_shape:
raise ValueError('Mask needs to have the same shape as weights. '
Expand Down Expand Up @@ -669,6 +675,7 @@ def __call__(self, inputs: Array) -> Array:
strides,
padding_lax,
rhs_dilation=self.kernel_dilation,
transpose_kernel=self.transpose_kernel,
precision=self.precision)

if self.padding == 'CIRCULAR':
Expand All @@ -691,12 +698,20 @@ def __call__(self, inputs: Array) -> Array:
-(y_dim - x_dim) % (2 * x_dim)
for y_dim, x_dim in zip(y.shape[1:-1], scaled_x_dims)
]
# Divide the padding equaly between left and right. The choice to put
# "+1" on the left (and not on the right) represents a convention for
# aligning even-sized kernels.
total_pad = [
((size_diff + 1) // 2, size_diff // 2) for size_diff in size_diffs
]
if self.transpose_kernel:
# If the kernel is transposed, the "+1" is put on the right to
# mirror the regular convolution. If the same kernel parameters are used
# as for Conv, this layer then computes the proper transpose convolution.
total_pad = [
(size_diff // 2, (size_diff + 1) // 2) for size_diff in size_diffs
]
else:
# Divide the padding equally between left and right. The choice to put
# "+1" on the left (and not on the right) represents a convention for
# aligning even-sized kernels.
total_pad = [
((size_diff + 1) // 2, size_diff // 2) for size_diff in size_diffs
]
y = jnp.pad(y, [(0, 0)] + total_pad + [(0, 0)])
# Wrap the result periodically around each spatial dimension,
# one by one.
Expand Down
17 changes: 17 additions & 0 deletions tests/linen/linen_linear_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,23 @@ def test_circular_conv_transpose_2d_custom_bias(self):
correct_ans = np.expand_dims(correct_ans, (0, 3))
np.testing.assert_allclose(y, correct_ans)

@parameterized.product(
use_bias=(True, False))
def test_transpose_kernel_conv_transpose(self, use_bias):
rng = dict(params=random.PRNGKey(0))
x = jnp.ones((1, 15, 15, 3))
conv_module = nn.ConvTranspose(
features=4,
use_bias=use_bias,
strides=(2, 2),
kernel_size=(6, 6),
padding='CIRCULAR',
transpose_kernel=True,
)
y, initial_params = conv_module.init_with_output(rng, x)
self.assertEqual(initial_params['params']['kernel'].shape, (6, 6, 4, 3))
self.assertEqual(y.shape, (1, 30, 30, 4))

@parameterized.product(
module=(nn.Conv, nn.ConvLocal)
)
Expand Down