diff --git a/paddle/fluid/primitive/codegen/gen.py b/paddle/fluid/primitive/codegen/gen.py index 90970ba3c3119..60131dda70b10 100644 --- a/paddle/fluid/primitive/codegen/gen.py +++ b/paddle/fluid/primitive/codegen/gen.py @@ -101,8 +101,10 @@ 'scatter_grad', 'scatter_nd_add_grad', 'slice_grad', + 'squeeze_grad', 'tile_grad', 'topk_grad', + 'unsqueeze_grad', ] # whole vjp list of primitive op vjp diff --git a/paddle/fluid/primitive/rule/vjp/details.h b/paddle/fluid/primitive/rule/vjp/details.h index 3e577c225387c..59c031952ee7f 100644 --- a/paddle/fluid/primitive/rule/vjp/details.h +++ b/paddle/fluid/primitive/rule/vjp/details.h @@ -877,6 +877,51 @@ void softmax_grad(const Tensor& out, } } +template +void squeeze_grad(const Tensor& xshape, + const Tensor& out_grad, + const IntArray& axis, + Tensor* x_grad) { + if (x_grad) { + auto x_grad_out = unsqueeze(out_grad, axis); + set_output(x_grad_out, x_grad); + } +} + +template +void unsqueeze_grad(const Tensor& xshape, + const Tensor& out_grad, + const IntArray& axis, + Tensor* x_grad) { + // for xshape = [10, 2, 5], axis = [3, 1, 1], out_grad.shape = [10, 1, 1, 2, + // 5, 1], it outputs squeeze axis = [5, 2, 1] + const auto& IncreaseAxis = [](std::vector* axis_data, + int64_t pivot) { + for (size_t i = 0; i < axis_data->size(); ++i) { + if ((*axis_data)[i] >= pivot) (*axis_data)[i] += 1; + } + }; + const auto& GetRealAxis = [&](const IntArray& axis) -> decltype(auto) { + // for axis = [0, 3, 3], it outputs [0, 3+1, 3], because unsqueeze supports + // duplicated axes. 
+ std::vector output_axis; + const int64_t x_rank = xshape.dims().size() - 1; + const std::vector axis_data = axis.GetData(); + for (size_t i = 0; i < axis_data.size(); ++i) { + int64_t value = axis_data[i]; + if (value < 0) value += (x_rank + i + 1); + IncreaseAxis(&output_axis, value); + output_axis.push_back(value); + } + return output_axis; + }; + + if (x_grad) { + auto x_grad_out = squeeze(out_grad, GetRealAxis(axis)); + set_output(x_grad_out, x_grad); + } +} + template void matmul_grad(const Tensor& x, const Tensor& y, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 90df04d732204..042e2bc7af73e 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -4286,16 +4286,6 @@ void SqueezeInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config) { const auto& x_dims = x.dims(); - // Check input tensor dims (<6) Eigen limit. - PADDLE_ENFORCE_LE(x_dims.size(), - 6, - phi::errors::InvalidArgument( - "The dimensions of Input(X) " - "should be in the range of [1, 6] (Eigen limit)." - "But received X's dimensions = %d, X's shape = [%s].", - x_dims.size(), - x_dims)); - if (!config.is_runtime && axes.FromTensor()) { // compile time infershape, set all elements to -1. int output_size = static_cast(x.dims().size() - axes.GetData().size()); diff --git a/python/paddle/decomposition/decomp.py b/python/paddle/decomposition/decomp.py index ba6dc8e862e2a..d613be161b681 100644 --- a/python/paddle/decomposition/decomp.py +++ b/python/paddle/decomposition/decomp.py @@ -213,7 +213,7 @@ def decompose( blacklist (frozenset): The Operators that will be exclude when decomposed into primitives. whitelist (frozenset): Only the operators in whitelist will be decomposed into primitives. start_index (int): The start index of decomposed operator in global block, default 0; - end_index (int): The end index of decomposed operator in global block, default -1 means all ops will be composed. 
+ end_index (int): The end index of decomposed operator in global block, default -1 means all ops will be decomposed. start_index and end_index form a left-closed, right-open interval, i.e. [start_index, end_index). Returns: dst_vars (list): A list contains all vars which replace origin ones in src_vars. diff --git a/test/deprecated/legacy_test/CMakeLists.txt b/test/deprecated/legacy_test/CMakeLists.txt index 4968e979a137a..0ff6d25736133 100644 --- a/test/deprecated/legacy_test/CMakeLists.txt +++ b/test/deprecated/legacy_test/CMakeLists.txt @@ -750,6 +750,7 @@ set_tests_properties(test_graph_send_uv_op PROPERTIES TIMEOUT 60) set_tests_properties(test_pretrained_model PROPERTIES TIMEOUT 120) set_tests_properties(test_model PROPERTIES TIMEOUT 300) set_tests_properties(test_pretrained_model PROPERTIES TIMEOUT 600) +set_tests_properties(test_squeeze2_op_rename PROPERTIES TIMEOUT 120) if(APPLE) set_tests_properties(test_callback_early_stop PROPERTIES TIMEOUT 300) @@ -772,11 +773,9 @@ set(TEST_CINN_OPS test_top_k_v2_op test_elementwise_mul_op test_gather_nd_op - test_squeeze2_op test_elementwise_pow_op test_transpose_op test_reshape_op - test_unsqueeze2_op test_meshgrid_op test_scale_op test_scatter_op diff --git a/test/deprecated/legacy_test/test_squeeze2_op_rename.py b/test/deprecated/legacy_test/test_squeeze2_op_rename.py new file mode 100644 index 0000000000000..ed347eda7350b --- /dev/null +++ b/test/deprecated/legacy_test/test_squeeze2_op_rename.py @@ -0,0 +1,102 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest + +from test_attribute_var import UnittestBase + +import paddle +from paddle.base.framework import Program, program_guard + +paddle.enable_static() + + +class TestSqueeze2AxesTensor(UnittestBase): + def init_info(self): + self.shapes = [[2, 3, 4]] + self.save_path = os.path.join(self.temp_dir.name, 'squeeze_tensor') + + def test_static(self): + main_prog = Program() + startup_prog = Program() + with program_guard(main_prog, startup_prog): + fc = paddle.nn.Linear(4, 10) + x = paddle.randn([2, 3, 4]) + x.stop_gradient = False + feat = fc(x) # [2,3,10] + feat = paddle.unsqueeze(feat, [0, 2]) # [1, 2, 1, 3, 10] + # axes is a Variable + axes = paddle.assign([0, 2]) + out = paddle.squeeze(feat, axes) + out2 = paddle.squeeze(feat, axes) + + sgd = paddle.optimizer.SGD() + sgd.minimize(paddle.mean(out)) + self.assertTrue("Var[" in str(main_prog)) + + exe = paddle.static.Executor() + exe.run(startup_prog) + res = exe.run(fetch_list=[feat, out, out2]) + self.assertEqual(res[0].shape, (1, 2, 1, 3, 10)) + self.assertEqual(res[1].shape, (2, 3, 10)) + self.assertEqual(res[2].shape, (2, 3, 10)) + + paddle.static.save_inference_model(self.save_path, [x], [out], exe) + # Test for Inference Predictor + infer_out = self.infer_prog() + self.assertEqual(infer_out.shape, (2, 3, 10)) + + +class TestSqueeze2AxesTensorList(UnittestBase): + def init_info(self): + self.shapes = [[2, 3, 4]] + self.save_path = os.path.join(self.temp_dir.name, 'squeeze_tensor') + + def test_static(self): + main_prog = Program() + startup_prog = Program() + with 
program_guard(main_prog, startup_prog): + fc = paddle.nn.Linear(4, 10) + x = paddle.randn([2, 3, 4]) + x.stop_gradient = False + feat = fc(x) # [2,3,10] + feat = paddle.unsqueeze(feat, [0, 2]) # [1, 2, 1, 3, 10] + # axes is a list[Variable] + axes = [ + paddle.full([1], 0, dtype='int32'), + paddle.full([1], 2, dtype='int32'), + ] + out = paddle.squeeze(feat, axes) + out2 = paddle.squeeze(feat, axes) + + sgd = paddle.optimizer.SGD() + sgd.minimize(paddle.mean(out)) + self.assertTrue("Vars[" in str(main_prog)) + + exe = paddle.static.Executor() + exe.run(startup_prog) + res = exe.run(fetch_list=[feat, out, out2]) + self.assertEqual(res[0].shape, (1, 2, 1, 3, 10)) + self.assertEqual(res[1].shape, (2, 3, 10)) + self.assertEqual(res[2].shape, (2, 3, 10)) + + paddle.static.save_inference_model(self.save_path, [x], [out], exe) + # Test for Inference Predictor + infer_out = self.infer_prog() + self.assertEqual(infer_out.shape, (2, 3, 10)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/ir/pir/cinn/sub_graphs/test_sub_graph_59.py b/test/ir/pir/cinn/sub_graphs/test_sub_graph_59.py index 088def7f45007..f0cc7826e933d 100644 --- a/test/ir/pir/cinn/sub_graphs/test_sub_graph_59.py +++ b/test/ir/pir/cinn/sub_graphs/test_sub_graph_59.py @@ -114,7 +114,7 @@ def test_ast_prim_cinn(self): for st, cinn in zip( paddle.utils.flatten(st_out), paddle.utils.flatten(cinn_out) ): - np.testing.assert_allclose(st.numpy(), cinn.numpy(), atol=1e-8) + np.testing.assert_allclose(st.numpy(), cinn.numpy(), atol=1e-6) if __name__ == '__main__': diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index 767775eb17db4..845870d4b787d 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -943,6 +943,8 @@ set(TEST_CINN_OPS test_group_norm_op test_tile_op test_sum_op + test_squeeze2_op + test_unsqueeze2_op test_elementwise_min_op test_take_along_axis_op test_strided_slice_op diff --git 
a/test/deprecated/legacy_test/test_squeeze2_op.py b/test/legacy_test/test_squeeze2_op.py similarity index 67% rename from test/deprecated/legacy_test/test_squeeze2_op.py rename to test/legacy_test/test_squeeze2_op.py index b462d639a6703..a2f549d415fc6 100755 --- a/test/deprecated/legacy_test/test_squeeze2_op.py +++ b/test/legacy_test/test_squeeze2_op.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,16 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import unittest import numpy as np from op_test import OpTest, convert_float_to_uint16 -from test_attribute_var import UnittestBase import paddle from paddle.base import core -from paddle.base.framework import Program, program_guard from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -31,7 +28,7 @@ class TestSqueezeOp(OpTest): def setUp(self): self.op_type = "squeeze2" - self.prim_op_type = "comp" + self.prim_op_type = "prim" self.python_api = paddle.squeeze self.public_python_api = paddle.squeeze self.python_out_sig = [ @@ -58,7 +55,6 @@ def if_enable_cinn(self): def test_check_output(self): self.check_output( no_check_set=['XShape'], - check_prim=True, check_pir=True, check_prim_pir=True, ) @@ -67,7 +63,6 @@ def test_check_grad(self): self.check_grad( ["X"], "Out", - check_prim=True, check_pir=True, check_prim_pir=True, ) @@ -191,81 +186,6 @@ def init_dtype(self): self.dtype = np.uint16 -class TestSqueeze2AxesTensor(UnittestBase): - def init_info(self): - self.shapes = [[2, 3, 4]] - self.save_path = os.path.join(self.temp_dir.name, 'squeeze_tensor') - - def test_static(self): - main_prog = Program() - startup_prog = Program() - with program_guard(main_prog, startup_prog): - fc = 
paddle.nn.Linear(4, 10) - x = paddle.randn([2, 3, 4]) - x.stop_gradient = False - feat = fc(x) # [2,3,10] - feat = paddle.unsqueeze(feat, [0, 2]) # [1, 2, 3, 1, 10] - # axes is a Variable - axes = paddle.assign([0, 2]) - out = paddle.squeeze(feat, axes) - out2 = paddle.squeeze(feat, axes) - - sgd = paddle.optimizer.SGD() - sgd.minimize(paddle.mean(out)) - self.assertTrue("Var[" in str(main_prog)) - - exe = paddle.static.Executor() - exe.run(startup_prog) - res = exe.run(fetch_list=[feat, out, out2]) - self.assertEqual(res[0].shape, (1, 2, 1, 3, 10)) - self.assertEqual(res[1].shape, (2, 3, 10)) - self.assertEqual(res[2].shape, (2, 3, 10)) - - paddle.static.save_inference_model(self.save_path, [x], [out], exe) - # Test for Inference Predictor - infer_out = self.infer_prog() - self.assertEqual(infer_out.shape, (2, 3, 10)) - - -class TestSqueeze2AxesTensorList(UnittestBase): - def init_info(self): - self.shapes = [[2, 3, 4]] - self.save_path = os.path.join(self.temp_dir.name, 'squeeze_tensor') - - def test_static(self): - main_prog = Program() - startup_prog = Program() - with program_guard(main_prog, startup_prog): - fc = paddle.nn.Linear(4, 10) - x = paddle.randn([2, 3, 4]) - x.stop_gradient = False - feat = fc(x) # [2,3,10] - feat = paddle.unsqueeze(feat, [0, 2]) # [1, 2, 3, 1, 10] - # axes is a list[Variable] - axes = [ - paddle.full([1], 0, dtype='int32'), - paddle.full([1], 2, dtype='int32'), - ] - out = paddle.squeeze(feat, axes) - out2 = paddle.squeeze(feat, axes) - - sgd = paddle.optimizer.SGD() - sgd.minimize(paddle.mean(out)) - self.assertTrue("Vars[" in str(main_prog)) - - exe = paddle.static.Executor() - exe.run(startup_prog) - res = exe.run(fetch_list=[feat, out, out2]) - self.assertEqual(res[0].shape, (1, 2, 1, 3, 10)) - self.assertEqual(res[1].shape, (2, 3, 10)) - self.assertEqual(res[2].shape, (2, 3, 10)) - - paddle.static.save_inference_model(self.save_path, [x], [out], exe) - # Test for Inference Predictor - infer_out = self.infer_prog() - 
self.assertEqual(infer_out.shape, (2, 3, 10)) - - # test api class TestSqueezeAPI(unittest.TestCase): def setUp(self): diff --git a/test/deprecated/legacy_test/test_unsqueeze2_op.py b/test/legacy_test/test_unsqueeze2_op.py similarity index 98% rename from test/deprecated/legacy_test/test_unsqueeze2_op.py rename to test/legacy_test/test_unsqueeze2_op.py index 65b7420c02d52..2a0c0ec1620ff 100755 --- a/test/deprecated/legacy_test/test_unsqueeze2_op.py +++ b/test/legacy_test/test_unsqueeze2_op.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -37,7 +37,7 @@ def setUp(self): "Out": self.inputs["X"].reshape(self.new_shape), "XShape": np.random.random(self.ori_shape).astype("float64"), } - self.prim_op_type = "comp" + self.prim_op_type = "prim" self.if_enable_cinn() def if_enable_cinn(self): @@ -46,7 +46,6 @@ def if_enable_cinn(self): def test_check_output(self): self.check_output( no_check_set=["XShape"], - check_prim=True, check_pir=True, check_prim_pir=True, ) @@ -55,7 +54,6 @@ def test_check_grad(self): self.check_grad( ["X"], "Out", - check_prim=True, check_pir=True, check_prim_pir=True, )