
Support asymmetric padding for "same" padding with even kernel size #2676

Open
dfsfdfse opened this issue Jan 9, 2025 · 1 comment

dfsfdfse commented Jan 9, 2025

Describe the bug
PaddingConfig2d::Same doesn't preserve the input shape when kernel_size = [2, 2]

Additional context

use burn::backend::wgpu::WgpuDevice;
use burn::backend::Wgpu;
use burn::nn::pool::MaxPool2dConfig;
use burn::nn::PaddingConfig2d;
use burn::tensor::{Distribution, Tensor};

fn main() {
    let device = WgpuDevice::BestAvailable;
    let data: Tensor<Wgpu, 4> = Tensor::random([1, 3, 16, 16], Distribution::Default, &device);
    let pool = MaxPool2dConfig::new([2, 2])
        .with_padding(PaddingConfig2d::Same)
        .init();
    let v = pool.forward(data);
    // Expected [1, 3, 16, 16]; the output comes out smaller instead
    println!("{:?}", v.shape());
}
@laggui added the bug (Something isn't working) and enhancement (Enhance existing features) labels on Jan 9, 2025

laggui (Member) commented Jan 9, 2025

Thanks for flagging!

For even kernel sizes, the "same" padding scheme actually requires asymmetric padding in order to produce an output with the same shape as the input.
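A minimal sketch of why the split is asymmetric (a hypothetical helper, not Burn's API): with stride 1, "same" padding needs kernel_size - 1 total padded cells per dimension, which is odd whenever the kernel size is even, so the two sides cannot be equal.

// Hypothetical helper, not Burn's API: the "same" padding split along
// one dimension for stride 1. Total padding is kernel_size - 1, which
// is odd for even kernels, so one side must get one extra cell.
fn same_padding_1d(kernel_size: usize) -> (usize, usize) {
    let total = kernel_size - 1;
    let begin = total / 2;   // e.g. 0 for kernel 2, 1 for kernel 4
    let end = total - begin; // e.g. 1 for kernel 2, 2 for kernel 4
    (begin, end)
}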

We only support symmetric padding right now; it might be a good idea to add a check in the meantime 😅
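One possible shape such a check could take (names are illustrative, not Burn's actual API):

// Hypothetical guard, not Burn's actual API: reject the unsupported
// combination loudly instead of silently returning a smaller output.
fn check_same_padding_supported(kernel_size: [usize; 2]) {
    for k in kernel_size {
        assert!(
            k % 2 == 1,
            "PaddingConfig2d::Same with even kernel size {k} requires \
             asymmetric padding, which is not supported yet"
        );
    }
}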

If you try a kernel size of 2, 4, 6, or any other even number with "same" padding, I believe you will observe the same behavior.
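For instance, with stride 1 and input size 16, and assuming the symmetric padding per side is computed as floor((k - 1) / 2): kernel 2 gives padding 0 and output 16 - 2 + 1 = 15, and kernel 4 gives padding 1 and output 16 + 2 - 4 + 1 = 15, both one short of the input.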

@laggui changed the title from "PaddingConfig::Same doesn't return the same input shape when kernel_size = [2, 2]" to "Support asymmetric padding for "same" padding with even kernel size" on Jan 9, 2025