-
-
Notifications
You must be signed in to change notification settings - Fork 149
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
New padding options for Conv
and ConvTranspose
#658
New padding options for Conv
and ConvTranspose
#658
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the contribution! This is excellent; thank you for implementing so much.
My main two comments are:
- Can you add some more documentation to the docstrings? (E.g. what does it mean to use
padding_mode == "CIRCULAR"
inConvTranspose
?) - Can you add some tests? I think I spot a few cases where the wrong thing is going to happen at the moment (I've commented on a couple of them below.)
equinox/nn/_conv.py
Outdated
@@ -80,6 +116,8 @@ def __init__( | |||
concatenating the results along the output channel dimension. | |||
`in_channels` must be divisible by `groups`. | |||
- `use_bias`: Whether to add on a bias after the convolution. | |||
- `padding_mode`: One of the following strings specifying the padding values. | |||
`'ZEROS'` (default), `'REFLECT'`, `'REPLICATE'`, or `'CIRCULAR'`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you explain what each of these options does?
rhs_shape = tuple( | ||
d * (k - 1) + 1 for k, d in zip(self.kernel_size, self.dilation) | ||
) | ||
padding = lax.padtype_to_pads( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this is public API, unfortunately. (It's not listed here, anyway.) Maybe we can ask to add it to the public API for JAX?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you raise a new issue in JAX? I think your proposal is much more likely to be accepted
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done! We'll see what they say.
equinox/nn/_conv.py
Outdated
"'REFLECT' or 'REPLICATE' padding mode is not implemented" | ||
) | ||
elif padding_mode == "CIRCULAR": | ||
if np.any(output_padding): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: I prefer avoiding implicit-cast-to-bool. (Here for the int->bool conversion on each element.) Can this be if any(x !=0 for x in output_padding)
?
Some more documentation and tests have been added according to your suggestion. Thanks a lot! |
Looks like the pre-commit hooks are failing. (See CONTRIBUTING.md for setting these up locally.) Modulo that detail, this PR LGTM! Thank you for putting it together, it'll be a great addition to have. Let me know once you've got the tests passing and I'll merge this PR. |
Some type hints have been fixed. There seems to be something wrong with the test. It always reports |
Rebase on top of the latest |
And merged! Thank you for the contribution. this is great to have. This will appear in the next release of Equinox :) |
Nice! Also thanks for your time! |
* add new padding options for Conv and ConvTranspose * Update _conv.py * Add tests for the padding of `Conv` and `ConvTranspose` * Fix some type hints * Fix the type of padding_t
According to #638, some changes are made to the padding in
Conv
andConvTranspose
.The
padding
argument now supports three string inputs, "VALID", "SAME", or "SAME_LOWER", similar tojax.lax.conv_general_dilated
.Furthermore, there is a new argument
padding_mode
which accepts string inputs including "ZEROS" (default), "REFLECT", "REPLICATE", and "CIRCULAR", similar to PyTorch.To make the "SAME" or "SAME_LOWER" padding work for
ConvTranspose
, the definition ofoutput_padding
is slightly changed to makeoutput_size = (input_size + output_padding) / stride
in these cases.Only "ZEROS" and "CIRCULAR" are implemented in
ConvTranspose
. "CIRCULAR" further requiresoutput_padding == 0
and
padding == "SAME"
or"SAME_LOWER"
to simplify it, which is the typical case when people use "CIRCULAR" padding.