Add transpose_kernel argument to ConvTranspose. #2578
Codecov Report
Coverage Diff (main vs. #2578)
            main     #2578    +/-
Coverage    79.50%   79.52%   +0.01%
Files       49       49
Lines       5206     5211     +5
Hits        4139     4144     +5
Misses      1067     1067
flax/linen/linear.py
@@ -688,12 +695,17 @@ def __call__(self, inputs: Array) -> Array:
         -(y_dim - x_dim) % (2 * x_dim)
         for y_dim, x_dim in zip(y.shape[1:-1], scaled_x_dims)
     ]
-    # Divide the padding equaly between left and right. The choice to put
+    # Divide the padding equally between left and right. The choice to put
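The arithmetic in the hunk above can be checked in isolation. The following is a minimal sketch, not the Flax implementation: the helper name `size_diffs_for` is hypothetical, and plain integer lists stand in for `y.shape[1:-1]` and `scaled_x_dims`. It only illustrates that unary minus binds tighter than `%` in Python, so each entry is the smallest non-negative `r` for which `y_dim + r - x_dim` is a multiple of `2 * x_dim`.

```python
def size_diffs_for(y_dims, scaled_x_dims):
    """Hypothetical helper mirroring the list comprehension in the diff.

    Python evaluates -(a) % b as (-(a)) % b, and % with a positive
    modulus always returns a non-negative result.
    """
    return [
        -(y_dim - x_dim) % (2 * x_dim)
        for y_dim, x_dim in zip(y_dims, scaled_x_dims)
    ]


# Illustrative sizes only; they are not taken from the PR.
diffs = size_diffs_for([10, 8, 7], [8, 8, 4])
print(diffs)
```

For the illustrative sizes used here, each result added to the corresponding `y_dim` gives `x_dim` plus a whole multiple of `2 * x_dim`.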
Maybe we could move this comment to the else branch and have a similar one for the positive branch?
I agree, the comments there could be clearer. How about this modified version, with clearer comments for both cases?
flax/linen/linear.py
    (size_diff // 2, (size_diff + 1) // 2) for size_diff in size_diffs
  ]
else:
  # This was the convention chosen previously and is kept here for
Thanks for improving this! I have an additional comment: since the else branch is the default case, can we keep the original explanation here? On the positive branch, can you explain why the +1 is now on the right? (e.g. "since the kernel is being transposed we add +1 on the right side instead to mirror the regular version").
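The two padding conventions discussed in this thread can be sketched as follows. This is a hedged illustration rather than the code in the PR: the function name `split_padding` is hypothetical, the `transpose_kernel=True` branch copies the expression visible in the diff, and the `else` branch is an assumption inferred from the question about the +1 having moved to the right.

```python
def split_padding(size_diff, transpose_kernel):
    """Split a total padding amount into a (left, right) pair.

    Hypothetical sketch: only the transpose_kernel=True expression
    appears in the diff; the other branch is an assumed mirror image.
    """
    if transpose_kernel:
        # Any odd remainder goes to the right side.
        return (size_diff // 2, (size_diff + 1) // 2)
    # Assumed previous convention: the odd remainder goes to the left.
    return ((size_diff + 1) // 2, size_diff // 2)


for total in (4, 5):
    print(total, split_padding(total, True), split_padding(total, False))
```

Both branches always sum to `size_diff`; they differ only when `size_diff` is odd, which is the case the requested comment needs to explain.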
18befe7 to 1efb0bf
Thanks @mathisgerdes for doing this! Exposing transpose_kernel is a great idea.
@mathisgerdes @cgarciae this PR has build failures so I have removed the "pull ready" label. Could you please fix those first?
1efb0bf to bbf7856
I believe those were problems with the previous commit, not any changes here. I've rebased to the latest commit.
Is there still something to be done here, @cgarciae, @andsteing?
Sorry for the delay. LGTM.
Thanks for doing this!
This is a proposal for #2577, adding a transpose_kernel option to ConvTranspose such that the previous behavior is maintained by default.
- transpose_kernel parameter to ConvTranspose, which is passed to jax.lax.conv_transpose.
- padding='CIRCULAR': To maintain the previous convention of ConvTranspose, the alignment of the kernel is changed depending on transpose_kernel.
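To make the effect of the flag concrete, here is a toy, single-channel 1D model in plain Python. It is a sketch under stated assumptions, not the jax.lax.conv_transpose implementation: the function name `toy_conv_transpose_1d` is hypothetical, there are no batch or channel axes (so only the spatial flip of the kernel is modeled), and the output size assumes a 'VALID'-style layout with a kernel at least as long as the stride.

```python
def toy_conv_transpose_1d(x, kernel, stride=1, transpose_kernel=False):
    """Toy single-channel transposed convolution (hypothetical helper).

    With transpose_kernel=True the kernel is scattered as-is, which is
    what the gradient of a regular convolution does. With False the
    kernel is applied spatially reversed relative to that.
    """
    k = len(kernel)
    out = [0.0] * ((len(x) - 1) * stride + k)
    for i, xi in enumerate(x):
        for j, w in enumerate(kernel):
            offset = j if transpose_kernel else k - 1 - j
            out[i * stride + offset] += xi * w
    return out


x = [1.0, 0.0]
kernel = [1.0, 2.0, 3.0]
print(toy_conv_transpose_1d(x, kernel, transpose_kernel=True))
print(toy_conv_transpose_1d(x, kernel, transpose_kernel=False))
```

In this toy model the two settings produce the same output for a symmetric kernel and mirrored kernel placement otherwise, which is why the circular-padding alignment has to depend on the flag.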