From a2b29749f61a0104618a6889d2080965646f5a67 Mon Sep 17 00:00:00 2001 From: PuQing Date: Sun, 24 Mar 2024 15:25:15 +0000 Subject: [PATCH 1/9] fix max_pool --- paddle/phi/api/yaml/backward.yaml | 8 ++++---- paddle/phi/api/yaml/ops.yaml | 4 ++-- paddle/phi/infermeta/backward.cc | 1 + paddle/phi/infermeta/backward.h | 1 + paddle/phi/infermeta/unary.cc | 4 +++- paddle/phi/infermeta/unary.h | 1 + paddle/phi/kernels/funcs/pooling.h | 13 +++++++------ paddle/phi/kernels/impl/pool_kernel_impl.h | 2 ++ paddle/phi/kernels/pool_kernel.h | 2 ++ paddle/phi/kernels/xpu/pool_kernel.cc | 2 ++ python/paddle/nn/functional/pooling.py | 6 +++--- test/legacy_test/test_pool1d_api.py | 17 +++++++++++++++++ test/legacy_test/test_pool2d_api.py | 17 +++++++++++++++++ test/legacy_test/test_pool3d_api.py | 17 +++++++++++++++++ 14 files changed, 79 insertions(+), 16 deletions(-) diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index c53f81cad71f41..73f0a54092471d 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -1543,8 +1543,8 @@ func : matrix_power_grad - backward_op : max_pool2d_with_index_grad - forward : max_pool2d_with_index(Tensor x, int[] kernel_size, int[] strides = {1, 1}, int[] paddings = {0, 0}, bool global_pooling = false, bool adaptive = false) -> Tensor(out), Tensor(mask) - args : (Tensor x, Tensor mask, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive) + forward : max_pool2d_with_index(Tensor x, int[] kernel_size, int[] strides = {1, 1}, int[] paddings = {0, 0}, bool global_pooling = false, bool adaptive = false, bool ceil_mode = false) -> Tensor(out), Tensor(mask) + args : (Tensor x, Tensor mask, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive, bool ceil_mode = false) output : Tensor(x_grad) infer_meta : func : MaxPoolWithIndexGradInferMeta @@ -1552,8 +1552,8 @@ func : max_pool2d_with_index_grad - backward_op : max_pool3d_with_index_grad - forward : max_pool3d_with_index(Tensor x, int[] kernel_size, int[] strides = {1, 1, 1}, int[] paddings = {0, 0, 0}, bool global_pooling = false, bool adaptive = false) -> Tensor(out), Tensor(mask) - args : (Tensor x, Tensor mask, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive) + forward : max_pool3d_with_index(Tensor x, int[] kernel_size, int[] strides = {1, 1, 1}, int[] paddings = {0, 0, 0}, bool global_pooling = false, bool adaptive = false, bool ceil_mode = false) -> Tensor(out), Tensor(mask) + args : (Tensor x, Tensor mask, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive, bool ceil_mode = false) output : Tensor(x_grad) infer_meta : func : MaxPoolWithIndexGradInferMeta diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 4759da3105e4c0..1003894ba3a461 100755 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1877,7 +1877,7 @@ backward : matrix_power_grad - op : max_pool2d_with_index - args : (Tensor x, int[] kernel_size, int[] strides= {1, 1}, int[] paddings = {0, 0}, bool global_pooling = false, bool adaptive = false) + args : (Tensor x, int[] kernel_size, int[] strides= {1, 1}, int[] paddings = {0, 0}, bool global_pooling = false, bool adaptive = false, bool ceil_mode = false) output : Tensor(out), Tensor(mask) infer_meta : func : MaxPoolWithIndexInferMeta @@ -1886,7 +1886,7 @@ backward : max_pool2d_with_index_grad - op : max_pool3d_with_index - args : (Tensor x, int[] kernel_size, int[] strides = {1, 1, 1}, int[] paddings = {0, 0, 0}, bool global_pooling = false, bool adaptive = false) + args : (Tensor x, int[] kernel_size, int[] strides = {1, 1, 1}, int[] paddings = {0, 0, 0}, bool global_pooling = false, bool adaptive = false, bool ceil_mode = false) output : Tensor(out), Tensor(mask) infer_meta : func : MaxPoolWithIndexInferMeta diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index a651346358034b..a856b566281ed0 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -710,6 +710,7 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x, const std::vector& paddings, bool global_pooling, bool adaptive, + bool ceil_mode, MetaTensor* dx) { dx->share_meta(x); } diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 364a90d750077d..83fec1c799715b 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -339,6 +339,7 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x, const std::vector& paddings, bool global_pooling, bool adaptive, + bool ceil_mode, MetaTensor* dx); void MeshgridGradInferMeta(const std::vector& inputs, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 64262af8885d9f..384943ea8437ed 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -2373,6 +2373,7 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x, const std::vector& paddings, bool global_pooling, bool adaptive, + bool ceil_mode, MetaTensor* out, MetaTensor* mask, MetaConfig config) { @@ -2431,7 +2432,8 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x, funcs::MaxPoolOutputSize(static_cast(x_dims[i + 2]), kernel_size_[i], paddings_[i], - strides[i])); + strides[i], + ceil_mode)); } } } diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 3314545faa1858..75ac0d27ad964b 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -364,6 +364,7 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x, const std::vector& paddings, bool global_pooling, bool adaptive, + bool ceil_mode, MetaTensor* out, MetaTensor* mask, MetaConfig config = MetaConfig()); diff --git a/paddle/phi/kernels/funcs/pooling.h b/paddle/phi/kernels/funcs/pooling.h index 3e91175e8a3920..cf1428397802f7 100644 --- a/paddle/phi/kernels/funcs/pooling.h +++ b/paddle/phi/kernels/funcs/pooling.h @@ -502,17 +502,18 @@ inline int PoolOutputSize(int input_size, return output_size; } -inline int MaxPoolOutputSize(int input_size, - int filter_size, - int padding, - int stride) { +inline int MaxPoolOutputSize( + int input_size, int filter_size, int padding, int stride, bool ceil_mode) { PADDLE_ENFORCE_NE( stride, 0, phi::errors::InvalidArgument( "The stride of MaxPool shall not be 0, but received %d.", stride)); - int output_size = (input_size - filter_size + 2 * padding) / stride + 1; - return output_size; + if (ceil_mode) { + return (input_size - filter_size + 2 * padding + stride - 1) / stride + 1; + } else { + return (input_size - filter_size + 2 * padding) / stride + 1; + } } template diff --git a/paddle/phi/kernels/impl/pool_kernel_impl.h b/paddle/phi/kernels/impl/pool_kernel_impl.h index 50a5195e771e8f..53d9d08f26445d 100644 --- a/paddle/phi/kernels/impl/pool_kernel_impl.h +++ b/paddle/phi/kernels/impl/pool_kernel_impl.h @@ -260,6 +260,7 @@ void MaxPool2dWithIndexKernel(const Context& ctx, const std::vector& paddings, bool global_pooling, bool adaptive, + bool ceil_mode UNUSED, DenseTensor* out, DenseTensor* mask) { MaxPoolWithIndexRawKernel(ctx, @@ -309,6 +310,7 @@ void MaxPool3dWithIndexKernel(const Context& ctx, const std::vector& paddings, bool global_pooling, bool adaptive, + bool ceil_mode UNUSED, DenseTensor* out, DenseTensor* mask) { MaxPoolWithIndexRawKernel(ctx, diff --git a/paddle/phi/kernels/pool_kernel.h b/paddle/phi/kernels/pool_kernel.h index e958d62d8c225a..d62b2868c7c619 100644 --- a/paddle/phi/kernels/pool_kernel.h +++ b/paddle/phi/kernels/pool_kernel.h @@ -60,6 +60,7 @@ void MaxPool2dWithIndexKernel(const Context& ctx, const std::vector& paddings, bool global_pooling, bool adaptive, + bool ceil_mode, DenseTensor* out, DenseTensor* mask); @@ -101,6 +102,7 @@ void MaxPool3dWithIndexKernel(const Context& ctx, const std::vector& paddings, bool global_pooling, bool adaptive, + bool ceil_mode, DenseTensor* out, DenseTensor* mask); diff --git a/paddle/phi/kernels/xpu/pool_kernel.cc b/paddle/phi/kernels/xpu/pool_kernel.cc index 466adade072c7a..cd8a62fb0a180a 100644 --- a/paddle/phi/kernels/xpu/pool_kernel.cc +++ b/paddle/phi/kernels/xpu/pool_kernel.cc @@ -14,6 +14,7 @@ #include "paddle/phi/kernels/pool_kernel.h" +#include "paddle/common/macros.h" #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/pooling.h" @@ -308,6 +309,7 @@ void MaxPool2dWithIndexKernel(const Context& ctx, const std::vector& paddings_t, bool global_pooling, bool adaptive, + bool ceil_mode UNUSED, DenseTensor* out, DenseTensor* mask) { using XPUType = typename XPUTypeTrait::Type; diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index dc79776afe90d6..6f367d7f71582a 100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -629,7 +629,7 @@ def max_pool1d( if in_dynamic_or_pir_mode(): if return_mask: pool_out = _C_ops.max_pool2d_with_index( - x, kernel_size, stride, padding, False, False + x, kernel_size, stride, padding, False, False, ceil_mode ) return ( (squeeze(pool_out[0], [2]), squeeze(pool_out[1], [2])) @@ -1201,7 +1201,7 @@ def max_pool2d( if in_dynamic_or_pir_mode(): if return_mask: output = _C_ops.max_pool2d_with_index( - x, kernel_size, stride, padding, False, False + x, kernel_size, stride, padding, False, False, ceil_mode ) return output if return_mask else output[0] else: @@ -1366,7 +1366,7 @@ def max_pool3d( if in_dynamic_or_pir_mode(): if return_mask: output = _C_ops.max_pool3d_with_index( - x, kernel_size, stride, padding, False, False + x, kernel_size, stride, padding, False, False, ceil_mode ) return output if return_mask else output[0] else: diff --git a/test/legacy_test/test_pool1d_api.py b/test/legacy_test/test_pool1d_api.py index 6fac04f468ebec..dcb897973e7868 100644 --- a/test/legacy_test/test_pool1d_api.py +++ b/test/legacy_test/test_pool1d_api.py @@ -296,6 +296,22 @@ def check_avg_dygraph_padding_same(self, place): np.testing.assert_allclose(result.numpy(), result_np, rtol=1e-05) + def check_max_pool_return_mask_ceil(self, place): + with base.dygraph.guard(place): + input_np = np.random.random([1, 3, 6]).astype("float32") + input = paddle.to_tensor(input_np) + result, _ = F.max_pool1d( + input, kernel_size=5, ceil_mode=True, return_mask=True + ) + result_np = max_pool1D_forward_naive( + input_np, + ksize=[5], + strides=[5], + paddings=[0], + ceil_mode=True, + ) + np.testing.assert_allclose(result.numpy(), result_np, rtol=1e-05) + def test_pool1d(self): for place in self.places: self.check_max_dygraph_results(place) @@ -306,6 +322,7 @@ def test_pool1d(self): self.check_avg_dygraph_padding_same(place) self.check_max_dygraph_return_index_results(place) self.check_avg_static_results_fp16(place) + self.check_max_pool_return_mask_ceil(place) class TestPool1DError_API(unittest.TestCase): diff --git a/test/legacy_test/test_pool2d_api.py b/test/legacy_test/test_pool2d_api.py index b371a45d49ffc7..0bd1cae2f7e0de 100644 --- a/test/legacy_test/test_pool2d_api.py +++ b/test/legacy_test/test_pool2d_api.py @@ -357,6 +357,22 @@ def check_avg_divisor(self, place): result = avg_pool2d_dg(input) np.testing.assert_allclose(result.numpy(), result_np, rtol=1e-05) + def check_max_pool_return_mask_ceil(self, place): + with base.dygraph.guard(place): + input_np = np.random.random([2, 3, 33, 33]).astype("float32") + input = paddle.to_tensor(input_np) + result, _ = max_pool2d( + input, kernel_size=5, ceil_mode=True, return_mask=True + ) + result_np = pool2D_forward_naive( + input_np, + ksize=[5, 5], + strides=[5, 5], + paddings=[0, 0], + ceil_mode=True, + ) + np.testing.assert_allclose(result.numpy(), result_np, rtol=1e-05) + def test_pool2d(self): for place in self.places: self.check_max_dygraph_results(place) @@ -368,6 +384,7 @@ def test_pool2d(self): self.check_max_dygraph_padding_results(place) self.check_max_dygraph_ceilmode_results(place) self.check_max_dygraph_nhwc_results(place) + self.check_max_pool_return_mask_ceil(place) @test_with_pir_api def test_pool2d_static(self): diff --git a/test/legacy_test/test_pool3d_api.py b/test/legacy_test/test_pool3d_api.py index 55c782faa5d27c..4e7f4bb032ca3e 100644 --- a/test/legacy_test/test_pool3d_api.py +++ b/test/legacy_test/test_pool3d_api.py @@ -356,6 +356,22 @@ def check_avg_divisor(self, place): ) np.testing.assert_allclose(result.numpy(), result_np, rtol=1e-05) + def check_max_pool_return_mask_ceil(self, place): + with base.dygraph.guard(place): + input_np = np.random.random([1, 2, 6, 33, 33]).astype("float32") + input = paddle.to_tensor(input_np) + result, _ = max_pool3d( + input, kernel_size=5, ceil_mode=True, return_mask=True + ) + result_np = pool3D_forward_naive( + input_np, + ksize=[5, 5, 5], + strides=[5, 5, 5], + paddings=[0, 0, 0], + ceil_mode=True, + ) + np.testing.assert_allclose(result.numpy(), result_np, rtol=1e-05) + def test_pool3d(self): paddle.enable_static() for place in self.places: @@ -368,6 +384,7 @@ def test_pool3d(self): self.check_avg_divisor(place) self.check_max_dygraph_ndhwc_results(place) self.check_max_dygraph_ceilmode_results(place) + self.check_max_pool_return_mask_ceil(place) @test_with_pir_api def test_static_fp16_gpu(self): From 3b188b559991b57f5088a04026922897e132ecf3 Mon Sep 17 00:00:00 2001 From: PuQing Date: Sun, 24 Mar 2024 15:26:31 +0000 Subject: [PATCH 2/9] fix stft --- python/paddle/signal.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/signal.py b/python/paddle/signal.py index 8e64bc2e3400a6..5f49c88567b969 100644 --- a/python/paddle/signal.py +++ b/python/paddle/signal.py @@ -362,6 +362,7 @@ def stft( ), f'expected a 1D window tensor of size equal to win_length({win_length}), but got window with shape {window.shape}.' else: window = paddle.ones(shape=(win_length,), dtype=x.dtype) + window = paddle.to_tensor(window, place=x.place) if win_length < n_fft: pad_left = (n_fft - win_length) // 2 From a5a122705d49bd7191efb7dd5f2de7d9b6c3f819 Mon Sep 17 00:00:00 2001 From: PuQing Date: Mon, 25 Mar 2024 01:47:19 +0000 Subject: [PATCH 3/9] Add ceil_mode parameter to MaxPool2d and MaxPool3d functions --- paddle/phi/kernels/impl/pool_grad_kernel_impl.h | 3 +++ paddle/phi/kernels/pool_grad_kernel.h | 2 ++ paddle/phi/kernels/xpu/pool_grad_kernel.cc | 1 + test/legacy_test/test_pool_max_op.py | 6 ++++-- 4 files changed, 10 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/impl/pool_grad_kernel_impl.h b/paddle/phi/kernels/impl/pool_grad_kernel_impl.h index c8a42e0265fb80..0b5b5df98351b6 100644 --- a/paddle/phi/kernels/impl/pool_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/pool_grad_kernel_impl.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include "paddle/common/ddim.h" +#include "paddle/common/macros.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/pooling.h" #include "paddle/phi/kernels/pool_grad_kernel.h" @@ -262,6 +263,7 @@ void MaxPool2dWithIndexGradKernel(const Context& ctx, const std::vector& paddings, bool global_pooling, bool adaptive, + bool ceil_mode UNUSED, DenseTensor* dx) { MaxPoolWithIndexGradRawKernel(ctx, x, @@ -317,6 +319,7 @@ void MaxPool3dWithIndexGradKernel(const Context& ctx, const std::vector& paddings, bool global_pooling, bool adaptive, + bool ceil_mode UNUSED, DenseTensor* dx) { MaxPoolWithIndexGradRawKernel(ctx, x, diff --git a/paddle/phi/kernels/pool_grad_kernel.h b/paddle/phi/kernels/pool_grad_kernel.h index 2f813aa9dc050d..bdaf2cd172c080 100644 --- a/paddle/phi/kernels/pool_grad_kernel.h +++ b/paddle/phi/kernels/pool_grad_kernel.h @@ -96,6 +96,7 @@ void MaxPool2dWithIndexGradKernel(const Context& ctx, const std::vector& paddings, bool global_pooling, bool adaptive, + bool ceil_mode, DenseTensor* dx); template @@ -142,6 +143,7 @@ void MaxPool3dWithIndexGradKernel(const Context& ctx, const std::vector& paddings, bool global_pooling, bool adaptive, + bool ceil_mode, DenseTensor* dx); template diff --git a/paddle/phi/kernels/xpu/pool_grad_kernel.cc b/paddle/phi/kernels/xpu/pool_grad_kernel.cc index b03be5dd9449cb..f2eb7532b3a55e 100644 --- a/paddle/phi/kernels/xpu/pool_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/pool_grad_kernel.cc @@ -386,6 +386,7 @@ void MaxPool2dWithIndexGradKernel(const Context& ctx, const std::vector& paddings_t, bool global_pooling, bool adaptive, + bool ceil_mode UNUSED, DenseTensor* dx) { using XPUType = typename XPUTypeTrait::Type; diff --git a/test/legacy_test/test_pool_max_op.py b/test/legacy_test/test_pool_max_op.py index f2186dee7c3399..42340b517b26b2 100644 --- a/test/legacy_test/test_pool_max_op.py +++ b/test/legacy_test/test_pool_max_op.py @@ -143,9 +143,10 @@ def max_pool3d_with_index_wapper( paddings=[], global_pooling=False, adaptive=False, + ceil_mode=False, ): return paddle._C_ops.max_pool3d_with_index( - x, kernel_size, strides, paddings, global_pooling, adaptive + x, kernel_size, strides, paddings, global_pooling, adaptive, ceil_mode ) @@ -342,9 +343,10 @@ def max_pool2d_with_index_wapper( paddings=[], global_pooling=False, adaptive=False, + ceil_mode=False, ): return paddle._C_ops.max_pool2d_with_index( - x, kernel_size, strides, paddings, global_pooling, adaptive + x, kernel_size, strides, paddings, global_pooling, adaptive, ceil_mode ) From 0c92d8ee069479de75ac3aef1ad5674b4a30fafa Mon Sep 17 00:00:00 2001 From: PuQing Date: Mon, 25 Mar 2024 02:36:55 +0000 Subject: [PATCH 4/9] Add test case for converting STFT result to real numbers --- test/legacy_test/test_stft_op.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/test/legacy_test/test_stft_op.py b/test/legacy_test/test_stft_op.py index 665714f3e939ea..2b2f8fd6a5b39c 100644 --- a/test/legacy_test/test_stft_op.py +++ b/test/legacy_test/test_stft_op.py @@ -89,5 +89,28 @@ def test_check_grad_normal(self): paddle.disable_static() +class TestStftOpReal(unittest.TestCase): + def test_as_real(self): + input = np.random.randn(4410) + x = paddle.to_tensor( + data=input, dtype='float32', place=paddle.CUDAPlace(0) + ) + n_fft = 400 + res0 = paddle.signal.stft(n_fft=n_fft, x=x) + res_a = paddle.as_real(res0) + + res_np = res0.numpy() + res_p = paddle.to_tensor(data=res_np) + res_b = paddle.as_real(res_p) + + np.testing.assert_allclose( + res0.numpy(), res_p.numpy(), rtol=1e-5, atol=1e-5 + ) + + np.testing.assert_allclose( + res_a.numpy(), res_b.numpy(), rtol=1e-5, atol=1e-5 + ) + + if __name__ == '__main__': unittest.main() From 02eb1dc336cf650ac3b5c5d2fdbf55f893f7fb02 Mon Sep 17 00:00:00 2001 From: PuQing Date: Wed, 27 Mar 2024 03:57:45 +0000 Subject: [PATCH 5/9] Fix adaptive pooling ceil_mode bug --- python/paddle/nn/functional/pooling.py | 9 ++++++--- test/legacy_test/test_stft_op.py | 4 +--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index 6f367d7f71582a..1459e179d835b7 100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -1816,7 +1816,7 @@ def adaptive_max_pool1d(x, output_size, return_mask=False, name=None): x = unsqueeze(x, [2]) if in_dygraph_mode(): pool_out = _C_ops.max_pool2d_with_index( - x, pool_size, [1, 1], [0, 0], False, True + x, pool_size, [1, 1], [0, 0], False, True, False ) return ( (squeeze(pool_out[0], [2]), squeeze(pool_out[1], [2])) @@ -1847,6 +1847,7 @@ def adaptive_max_pool1d(x, output_size, return_mask=False, name=None): "pooling_type": 'max', "ksize": pool_size, "adaptive": True, + "ceil_mode": False, }, ) @@ -1910,7 +1911,7 @@ def adaptive_max_pool2d(x, output_size, return_mask=False, name=None): output_size[1] = in_w if in_dygraph_mode(): pool_out = _C_ops.max_pool2d_with_index( - x, output_size, [1, 1], [0, 0], False, True + x, output_size, [1, 1], [0, 0], False, True, False ) return pool_out if return_mask else pool_out[0] else: @@ -1937,6 +1938,7 @@ def adaptive_max_pool2d(x, output_size, return_mask=False, name=None): "pooling_type": 'max', "ksize": output_size, "adaptive": True, + "ceil_mode": False, }, ) return (pool_out, mask) if return_mask else pool_out @@ -2002,7 +2004,7 @@ def adaptive_max_pool3d(x, output_size, return_mask=False, name=None): if in_dygraph_mode(): # By default, strides is [1,1,1] and paddings is [0, 0, 0] pool_out = _C_ops.max_pool3d_with_index( - x, output_size, [1, 1, 1], [0, 0, 0], False, True + x, output_size, [1, 1, 1], [0, 0, 0], False, True, False ) return pool_out if return_mask else pool_out[0] else: @@ -2029,6 +2031,7 @@ def adaptive_max_pool3d(x, output_size, return_mask=False, name=None): "pooling_type": 'max', "ksize": output_size, "adaptive": True, + "ceil_mode": False, }, ) diff --git a/test/legacy_test/test_stft_op.py b/test/legacy_test/test_stft_op.py index 2b2f8fd6a5b39c..51573bf7af769f 100644 --- a/test/legacy_test/test_stft_op.py +++ b/test/legacy_test/test_stft_op.py @@ -92,9 +92,7 @@ def test_check_grad_normal(self): class TestStftOpReal(unittest.TestCase): def test_as_real(self): input = np.random.randn(4410) - x = paddle.to_tensor( - data=input, dtype='float32', place=paddle.CUDAPlace(0) - ) + x = paddle.to_tensor(data=input, dtype='float32') n_fft = 400 res0 = paddle.signal.stft(n_fft=n_fft, x=x) res_a = paddle.as_real(res0) From f5c8ad47af66393465b0223f01d1b9c48b5e95a5 Mon Sep 17 00:00:00 2001 From: PuQing Date: Sat, 6 Apr 2024 08:38:12 +0000 Subject: [PATCH 6/9] remove window to x device --- python/paddle/signal.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/paddle/signal.py b/python/paddle/signal.py index 5f49c88567b969..8e64bc2e3400a6 100644 --- a/python/paddle/signal.py +++ b/python/paddle/signal.py @@ -362,7 +362,6 @@ def stft( ), f'expected a 1D window tensor of size equal to win_length({win_length}), but got window with shape {window.shape}.' else: window = paddle.ones(shape=(win_length,), dtype=x.dtype) - window = paddle.to_tensor(window, place=x.place) if win_length < n_fft: pad_left = (n_fft - win_length) // 2 From c1bec1c05b743aaf057657f7c5ff7324e82f85ea Mon Sep 17 00:00:00 2001 From: PuQing Date: Sat, 6 Apr 2024 08:40:33 +0000 Subject: [PATCH 7/9] add shape check for pool op --- test/legacy_test/test_pool1d_api.py | 8 +++++++- test/legacy_test/test_pool2d_api.py | 8 +++++++- test/legacy_test/test_pool3d_api.py | 8 +++++++- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/test/legacy_test/test_pool1d_api.py b/test/legacy_test/test_pool1d_api.py index dcb897973e7868..c4aa624bd76626 100644 --- a/test/legacy_test/test_pool1d_api.py +++ b/test/legacy_test/test_pool1d_api.py @@ -301,7 +301,12 @@ def check_max_pool_return_mask_ceil(self, place): input_np = np.random.random([1, 3, 6]).astype("float32") input = paddle.to_tensor(input_np) result, _ = F.max_pool1d( - input, kernel_size=5, ceil_mode=True, return_mask=True + input, + kernel_size=5, + stride=5, + padding=0, + ceil_mode=True, + return_mask=True, ) result_np = max_pool1D_forward_naive( input_np, @@ -311,6 +316,7 @@ def check_max_pool_return_mask_ceil(self, place): ceil_mode=True, ) np.testing.assert_allclose(result.numpy(), result_np, rtol=1e-05) + self.assertEqual(result.shape, list(result_np.shape)) def test_pool1d(self): for place in self.places: diff --git a/test/legacy_test/test_pool2d_api.py b/test/legacy_test/test_pool2d_api.py index 0bd1cae2f7e0de..bb02bbcce1208d 100644 --- a/test/legacy_test/test_pool2d_api.py +++ b/test/legacy_test/test_pool2d_api.py @@ -362,7 +362,12 @@ def check_max_pool_return_mask_ceil(self, place): input_np = np.random.random([2, 3, 33, 33]).astype("float32") input = paddle.to_tensor(input_np) result, _ = max_pool2d( - input, kernel_size=5, ceil_mode=True, return_mask=True + input, + kernel_size=5, + stride=5, + padding=0, + ceil_mode=True, + return_mask=True, ) result_np = pool2D_forward_naive( input_np, @@ -372,6 +377,7 @@ def check_max_pool_return_mask_ceil(self, place): ceil_mode=True, ) np.testing.assert_allclose(result.numpy(), result_np, rtol=1e-05) + self.assertEqual(result.shape, list(result_np.shape)) def test_pool2d(self): for place in self.places: diff --git a/test/legacy_test/test_pool3d_api.py b/test/legacy_test/test_pool3d_api.py index 4e7f4bb032ca3e..d1eb706f953bd7 100644 --- a/test/legacy_test/test_pool3d_api.py +++ b/test/legacy_test/test_pool3d_api.py @@ -361,7 +361,12 @@ def check_max_pool_return_mask_ceil(self, place): input_np = np.random.random([1, 2, 6, 33, 33]).astype("float32") input = paddle.to_tensor(input_np) result, _ = max_pool3d( - input, kernel_size=5, ceil_mode=True, return_mask=True + input, + kernel_size=5, + stride=5, + padding=0, + ceil_mode=True, + return_mask=True, ) result_np = pool3D_forward_naive( input_np, @@ -371,6 +376,7 @@ def check_max_pool_return_mask_ceil(self, place): ceil_mode=True, ) np.testing.assert_allclose(result.numpy(), result_np, rtol=1e-05) + self.assertEqual(result.shape, list(result_np.shape)) def test_pool3d(self): paddle.enable_static() From 04631137ad86b046e71ebeb25c33b6fa3b7242e5 Mon Sep 17 00:00:00 2001 From: PuQing Date: Mon, 17 Jun 2024 14:44:43 +0000 Subject: [PATCH 8/9] Add ceil_mode argument to max_pool2d_with_index and max_pool3d_with_index --- paddle/phi/ops/yaml/backward.yaml | 8 ++++---- paddle/phi/ops/yaml/ops.yaml | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index 39cd4ef9b073d7..9a72ba5026a653 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -1918,8 +1918,8 @@ composite : max_grad(x, out, out_grad, axis, keepdim, reduce_all, x_grad) - backward_op : max_pool2d_with_index_grad - forward : max_pool2d_with_index(Tensor x, int[] kernel_size, int[] strides = {1, 1}, int[] paddings = {0, 0}, bool global_pooling = false, bool adaptive = false) -> Tensor(out), Tensor(mask) - args : (Tensor x, Tensor mask, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive) + forward : max_pool2d_with_index(Tensor x, int[] kernel_size, int[] strides = {1, 1}, int[] paddings = {0, 0}, bool global_pooling = false, bool adaptive = false, bool ceil_mode = false) -> Tensor(out), Tensor(mask) + args : (Tensor x, Tensor mask, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive, bool ceil_mode = false) output : Tensor(x_grad) infer_meta : func : MaxPoolWithIndexGradInferMeta @@ -1927,8 +1927,8 @@ func : max_pool2d_with_index_grad - backward_op : max_pool3d_with_index_grad - forward : max_pool3d_with_index(Tensor x, int[] kernel_size, int[] strides = {1, 1, 1}, int[] paddings = {0, 0, 0}, bool global_pooling = false, bool adaptive = false) -> Tensor(out), Tensor(mask) - args : (Tensor x, Tensor mask, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive) + forward : max_pool3d_with_index(Tensor x, int[] kernel_size, int[] strides = {1, 1, 1}, int[] paddings = {0, 0, 0}, bool global_pooling = false, bool adaptive = false, bool ceil_mode = false) -> Tensor(out), Tensor(mask) + args : (Tensor x, Tensor mask, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive, bool ceil_mode = false) output : Tensor(x_grad) infer_meta : func : MaxPoolWithIndexGradInferMeta diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 45b7c7b8964b96..68930edd532753 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -2823,7 +2823,7 @@ interfaces : paddle::dialect::InferSymbolicShapeInterface - op : max_pool2d_with_index - args : (Tensor x, int[] kernel_size, int[] strides= {1, 1}, int[] paddings = {0, 0}, bool global_pooling = false, bool adaptive = false) + args : (Tensor x, int[] kernel_size, int[] strides= {1, 1}, int[] paddings = {0, 0}, bool global_pooling = false, bool adaptive = false, bool ceil_mode = false) output : Tensor(out), Tensor(mask) infer_meta : func : MaxPoolWithIndexInferMeta @@ -2832,7 +2832,7 @@ backward : max_pool2d_with_index_grad - op : max_pool3d_with_index - args : (Tensor x, int[] kernel_size, int[] strides = {1, 1, 1}, int[] paddings = {0, 0, 0}, bool global_pooling = false, bool adaptive = false) + args : (Tensor x, int[] kernel_size, int[] strides = {1, 1, 1}, int[] paddings = {0, 0, 0}, bool global_pooling = false, bool adaptive = false, bool ceil_mode = false) output : Tensor(out), Tensor(mask) infer_meta : func : MaxPoolWithIndexInferMeta From 1ebbd24be08aab2b10e9565fc099802a6b52fa28 Mon Sep 17 00:00:00 2001 From: PuQing Date: Mon, 1 Jul 2024 04:01:52 +0000 Subject: [PATCH 9/9] Add ceil_mode arg to max_pool2d_with_index and max_pool3d_with_index --- paddle/phi/ops/yaml/op_version.yaml | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/paddle/phi/ops/yaml/op_version.yaml b/paddle/phi/ops/yaml/op_version.yaml index a41a67e9ded17d..b6081079c4a328 100644 --- a/paddle/phi/ops/yaml/op_version.yaml +++ b/paddle/phi/ops/yaml/op_version.yaml @@ -392,6 +392,22 @@ - add_output : RoisNum comment : The number of RoIs in each image. +- op : max_pool2d_with_index + version : + - checkpoint : Upgrade max_pool2d_with_index, add a new attribute [ceil_mode]. + action : + - add_attr : ceil_mode + comment : When true, will use ceil instead of floor to compute the output shape. + default : "false" + +- op : max_pool3d_with_index + version : + - checkpoint : Upgrade max_pool3d_with_index, add a new attribute [ceil_mode]. + action : + - add_attr : ceil_mode + comment : When true, will use ceil instead of floor to compute the output shape. + default : "false" + - op : momentum version : - checkpoint : Upgrade momentum add 4 attributes [regularization_method, regularization_coeff, multi_precision, rescale_grad].