Add transpose_kernel argument to ConvTranspose.
mathisgerdes committed Dec 13, 2022
1 parent 5bd4cd2 commit bbf7856
Showing 2 changed files with 39 additions and 7 deletions.
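
For context, a minimal sketch (not part of this commit) of how the new flag changes the stored kernel layout, following the kernel_shape logic in the diff below; the input shape and feature counts here are illustrative assumptions:

  import jax
  import jax.numpy as jnp
  import flax.linen as nn

  x = jnp.ones((1, 8, 8, 3))  # NHWC input with 3 channels
  for flag in (False, True):
    layer = nn.ConvTranspose(features=5, kernel_size=(4, 4), transpose_kernel=flag)
    params = layer.init(jax.random.PRNGKey(0), x)
    print(flag, params['params']['kernel'].shape)
  # Expected: False -> (4, 4, 3, 5), i.e. kernel_size + (in_features, features)
  #           True  -> (4, 4, 5, 3), i.e. kernel_size + (features, in_features)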
29 changes: 22 additions & 7 deletions flax/linen/linear.py
@@ -583,6 +583,8 @@ class ConvTranspose(Module):
for details.
kernel_init: initializer for the convolutional kernel.
bias_init: initializer for the bias.
transpose_kernel: if True, flips the spatial axes and swaps the input/output
channel axes of the kernel.
"""
features: int
kernel_size: Union[int, Tuple[int, ...]]
@@ -596,6 +598,7 @@ class ConvTranspose(Module):
precision: PrecisionLike = None
kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
transpose_kernel: bool = False

@compact
def __call__(self, inputs: Array) -> Array:
@@ -638,7 +641,10 @@ def __call__(self, inputs: Array) -> Array:
strides = self.strides or (1,) * (inputs.ndim - 2)

in_features = jnp.shape(inputs)[-1]
kernel_shape = kernel_size + (in_features, self.features)
if self.transpose_kernel:
kernel_shape = kernel_size + (self.features, in_features)
else:
kernel_shape = kernel_size + (in_features, self.features)

if self.mask is not None and self.mask.shape != kernel_shape:
raise ValueError('Mask needs to have the same shape as weights. '
@@ -669,6 +675,7 @@ def __call__(self, inputs: Array) -> Array:
strides,
padding_lax,
rhs_dilation=self.kernel_dilation,
transpose_kernel=self.transpose_kernel,
precision=self.precision)

if self.padding == 'CIRCULAR':
@@ -691,12 +698,20 @@ def __call__(self, inputs: Array) -> Array:
-(y_dim - x_dim) % (2 * x_dim)
for y_dim, x_dim in zip(y.shape[1:-1], scaled_x_dims)
]
# Divide the padding equaly between left and right. The choice to put
# "+1" on the left (and not on the right) represents a convention for
# aligning even-sized kernels.
total_pad = [
((size_diff + 1) // 2, size_diff // 2) for size_diff in size_diffs
]
if self.transpose_kernel:
# If the kernel is transposed, the "+1" is put on the right to
# mirror the regular convolution. If the same kernel parameters are used
# as for Conv, this layer then computes the proper transpose convolution.
total_pad = [
(size_diff // 2, (size_diff + 1) // 2) for size_diff in size_diffs
]
else:
# Divide the padding equally between left and right. The choice to put
# "+1" on the left (and not on the right) represents a convention for
# aligning even-sized kernels.
total_pad = [
((size_diff + 1) // 2, size_diff // 2) for size_diff in size_diffs
]
y = jnp.pad(y, [(0, 0)] + total_pad + [(0, 0)])
# Wrap the result periodically around each spatial dimension,
# one by one.
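The layer simply forwards the flag to jax.lax.conv_transpose, whose transpose_kernel option flips the kernel's spatial axes and swaps its channel axes. A rough illustrative sketch (shapes are assumptions, not from the commit): a kernel laid out like a forward convolution's HWIO kernel mapping 3 -> 4 channels is consumed in the transposed direction, mapping 4 channels back to 3.

  import jax.numpy as jnp
  from jax import lax

  x = jnp.ones((1, 8, 8, 4))       # NHWC, 4 channels (the forward conv's output side)
  kernel = jnp.ones((3, 3, 3, 4))  # HWIO kernel of a forward conv mapping 3 -> 4 channels
  y = lax.conv_transpose(x, kernel, strides=(1, 1), padding='SAME',
                         transpose_kernel=True)
  print(y.shape)  # (1, 8, 8, 3): output channels come from the kernel's input axis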
17 changes: 17 additions & 0 deletions tests/linen/linen_linear_test.py
@@ -930,6 +930,23 @@ def test_circular_conv_transpose_2d_custom_bias(self):
correct_ans = np.expand_dims(correct_ans, (0, 3))
np.testing.assert_allclose(y, correct_ans)

@parameterized.product(
use_bias=(True, False))
def test_transpose_kernel_conv_transpose(self, use_bias):
rng = dict(params=random.PRNGKey(0))
x = jnp.ones((1, 15, 15, 3))
conv_module = nn.ConvTranspose(
features=4,
use_bias=use_bias,
strides=(2, 2),
kernel_size=(6, 6),
padding='CIRCULAR',
transpose_kernel=True,
)
y, initial_params = conv_module.init_with_output(rng, x)
self.assertEqual(initial_params['params']['kernel'].shape, (6, 6, 4, 3))
self.assertEqual(y.shape, (1, 30, 30, 4))

@parameterized.product(
module=(nn.Conv, nn.ConvLocal)
)
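As the comment added in linear.py notes, a ConvTranspose with transpose_kernel=True can reuse the kernel parameters of a matching Conv. A hedged usage sketch (illustrative assumption: stride 1 and no bias so the parameter trees line up):

  import jax
  import jax.numpy as jnp
  import flax.linen as nn

  conv = nn.Conv(features=4, kernel_size=(3, 3), padding='CIRCULAR', use_bias=False)
  deconv = nn.ConvTranspose(features=3, kernel_size=(3, 3), padding='CIRCULAR',
                            use_bias=False, transpose_kernel=True)

  x = jnp.ones((1, 8, 8, 3))
  params = conv.init(jax.random.PRNGKey(0), x)  # kernel shape (3, 3, 3, 4)
  y = conv.apply(params, x)                     # (1, 8, 8, 4)
  x_t = deconv.apply(params, y)                 # (1, 8, 8, 3), same kernel reused
  print(y.shape, x_t.shape)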
