Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New padding options for Conv and ConvTranspose #658

Merged
merged 5 commits into from
Feb 20, 2024

Conversation

ChenAo-Phys
Copy link
Contributor

According to #638, some changes are made to the padding in Conv and ConvTranspose.

The padding argument now supports three string inputs, "VALID", "SAME", or "SAME_LOWER", similar to jax.lax.conv_general_dilated.

Furthermore, there is a new argument padding_mode which accepts string inputs including "ZEROS" (default), "REFLECT", "REPLICATE", and "CIRCULAR", similar to PyTorch.

To make the "SAME" or "SAME_LOWER" padding work for ConvTranspose, the definition of output_padding is slightly changed to make output_size = (input_size + output_padding) / stride in these cases.

Only "ZEROS" and "CIRCULAR" are implemented in ConvTranspose. "CIRCULAR" further requires output_padding == 0
and padding == "SAME" or "SAME_LOWER" to simplify it, which is the typical case when people use "CIRCULAR" padding.

Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the contribution! This is excellent; thank you for implementing so much.

My main two comments are:

  1. Can you add some more documentation to the docstrings? (E.g. what does it mean to use padding_mode == "CIRCULAR" in ConvTranspose?)
  2. Can you add some tests? I think I spot a few cases where the wrong thing is going to happen at the moment (I've commented on a couple of them below.)

@@ -80,6 +116,8 @@ def __init__(
concatenating the results along the output channel dimension.
`in_channels` must be divisible by `groups`.
- `use_bias`: Whether to add on a bias after the convolution.
- `padding_mode`: One of the following strings specifying the padding values.
`'ZEROS'` (default), `'REFLECT'`, `'REPLICATE'`, or `'CIRCULAR'`.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain what each of these options does?

rhs_shape = tuple(
d * (k - 1) + 1 for k, d in zip(self.kernel_size, self.dilation)
)
padding = lax.padtype_to_pads(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is public API, unfortunately. (It's not listed here, anyway.) Maybe we can ask to add it to the public API for JAX?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you raise a new issue in JAX? I think your proposal is much more likely to be accepted

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done! We'll see what they say.

equinox/nn/_conv.py Show resolved Hide resolved
"'REFLECT' or 'REPLICATE' padding mode is not implemented"
)
elif padding_mode == "CIRCULAR":
if np.any(output_padding):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I prefer avoiding implicit-cast-to-bool. (Here for the int->bool conversion on each element.) Can this be if any(x !=0 for x in output_padding)?

@patrick-kidger patrick-kidger changed the base branch from main to dev February 16, 2024 22:12
@ChenAo-Phys
Copy link
Contributor Author

Thank you for the contribution! This is excellent; thank you for implementing so much.

My main two comments are:

  1. Can you add some more documentation to the docstrings? (E.g. what does it mean to use padding_mode == "CIRCULAR" in ConvTranspose?)
  2. Can you add some tests? I think I spot a few cases where the wrong thing is going to happen at the moment (I've commented on a couple of them below.)

Some more documentation and tests have been added according to your suggestion. Thanks a lot!

@patrick-kidger
Copy link
Owner

Looks like the pre-commit hooks are failing. (See CONTRIBUTING.md for setting these up locally.)

Modulo that detail, this PR LGTM! Thank you for putting it together, it'll be a great addition to have. Let me know once you've got the tests passing and I'll merge this PR.

@ChenAo-Phys
Copy link
Contributor Author

Some type hints have been fixed.

There seems to be something wrong with the test. It always reports
FAILED tests/test_errors.py::test_tracetime - Failed: DID NOT WARN. No warnings of type (<class 'UserWarning'>,) were emitted.

@patrick-kidger
Copy link
Owner

Rebase on top of the latest dev, I fixed this one a few days ago :)

@patrick-kidger patrick-kidger merged commit 23d983e into patrick-kidger:dev Feb 20, 2024
2 checks passed
@patrick-kidger
Copy link
Owner

And merged! Thank you for the contribution. this is great to have. This will appear in the next release of Equinox :)

@ChenAo-Phys
Copy link
Contributor Author

And merged! Thank you for the contribution. this is great to have. This will appear in the next release of Equinox :)

Nice! Also thanks for your time!

@ChenAo-Phys ChenAo-Phys deleted the ChenAo-Phys-patch-1 branch February 20, 2024 15:46
@patrick-kidger patrick-kidger mentioned this pull request Apr 14, 2024
patrick-kidger pushed a commit that referenced this pull request Apr 14, 2024
* add new padding options for Conv and ConvTranspose

* Update _conv.py

* Add tests for the padding of `Conv` and `ConvTranspose`

* Fix some type hints

* Fix the type of padding_t
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants