-
Notifications
You must be signed in to change notification settings - Fork 660
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding "count_include_pad" argument to flax.linen.pooling.avg_pool #2451
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #2451 +/- ##
==========================================
- Coverage 79.66% 78.83% -0.84%
==========================================
Files 49 49
Lines 4982 5070 +88
==========================================
+ Hits 3969 3997 +28
- Misses 1013 1073 +60 ☔ View full report in Codecov by Sentry. |
Hey @dslisleedh, thanks for creating this PR! |
To @cgarciae, Sorry, I tested my self but didn't share result in PR. Here's result.
This is first time for me to create PR, so tell me anything if I missed something. Thank you. |
Just noticed there aren't any tests for |
I found PoolTest class from ./tests/linen/test_linen.py
When I ran this code with the edits from this PR there were no problems.
|
@dslisleedh can you create one of more tests under |
…iv_shape for it raises error when there's no batch dimension
To @cgarciae I add some codes to TestPool and here's code I tested.
When I tested with the previous code, an error occurred in the non-batch avg_pool, so I corrected the PR. Thanks for telling me to test it for sure. Below is the test result with the modified code.
Thank you. |
Awesome @dslisleedh! Can you commit changes to the tests? |
…iv_shape for it raises error when there's no batch dimension
To @cgarciae Sure :) |
@dslisleedh can you add this test? None of the other tests used @parameterized.parameters(
{'count_include_pad': True},
{'count_include_pad': False})
def test_avg_pool_padding_same(self, count_include_pad):
x = jnp.array([1.0, 2.0, 3.0, 4.0]).reshape((1, 2, 2, 1))
pool = lambda x: nn.avg_pool(x, (2, 2), padding="SAME", count_include_pad=count_include_pad)
y = pool(x)
if count_include_pad:
expected_y = jnp.array([10.0 / 4, 6.0 / 4, 7.0 / 4, 4.0 / 4]).reshape((1, 2, 2, 1))
else:
expected_y = jnp.array([10.0 / 4, 6.0 / 2, 7.0 / 2, 4.0 / 1]).reshape((1, 2, 2, 1))
np.testing.assert_allclose(y, expected_y) |
@cgarciae Oh, I forgot that. Thank you. and here is result of your code.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is great @dslisleedh, thanks for going through with this!
What does this PR do?
Now version's flax.linen.pooling.avg_pool average window_sum result include padded tokens. I add argument whether to include padded tokens or not
Checklist
issues
documentation guidelines.
(No quality testing = no merge!)