From d3f2931df82a50fb84ff510af64f63c8f563efdc Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Mon, 20 May 2024 10:54:46 -0600 Subject: [PATCH] add cnn transpose weight flipping --- equinox/nn/_conv.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/equinox/nn/_conv.py b/equinox/nn/_conv.py index b9ddbb88..db0bb351 100644 --- a/equinox/nn/_conv.py +++ b/equinox/nn/_conv.py @@ -487,6 +487,20 @@ def __init__( See [these animations](https://github.com/vdumoulin/conv_arithmetic/blob/af6f818b0bb396c26da79899554682a8a499101d/README.md#transposed-convolution-animations) and [this report](https://arxiv.org/abs/1603.07285) for a nice reference. + !!! faq "FAQ" + + If you need to exactly transpose a convolutional layer, i.e. not just create an + operation with similar inductive biases but compute the actual linear transpose + of a specific CNN you can reshape the weights of the forward convolution + via the following: + + ```python + cnn = eqx.Conv(...) + cnn_t = eqx.ConvTranspose(...) + cnn_t = eqx.tree_at(lambda x: x.weight, cnn_t, jnp.flip(cnn.weight, + axis=tuple(range(2, cnn.weight.ndim))).swapaxes(0, 1)) + ``` + !!! warning `padding_mode='CIRCULAR'` is only implemented for `output_padding=0` and