From 0197dca6c37cc437a5bd88a5f1185bb010886bb6 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Tue, 14 May 2024 03:04:42 +0000 Subject: [PATCH 1/5] [Prim]Add more detail note for start/end index test=document_fix --- python/paddle/decomposition/decomp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/decomposition/decomp.py b/python/paddle/decomposition/decomp.py index ba6dc8e862e2a..dd88704a77ce2 100644 --- a/python/paddle/decomposition/decomp.py +++ b/python/paddle/decomposition/decomp.py @@ -213,7 +213,7 @@ def decompose( blacklist (frozenset): The Operators that will be exclude when decomposed into primitives. whitelist (frozenset): Only the operators in whitelist will be decomposed into primitives. start_index (int): The start index of decomposed operator in global block, default 0; - end_index (int): The end index of decomposed operator in global block, default -1 means all ops will be composed. + end_index (int): The end index of decomposed operator in global block, default -1 means all ops will be composed. start_index and end_index follow the principle of left open and right closed, that is [start_index, end_index). Returns: dst_vars (list): A list contains all vars which replace origin ones in src_vars. From ad1b9b11e7e9d04a30454d68af731aa51b92712b Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Tue, 14 May 2024 10:02:50 +0000 Subject: [PATCH 2/5] migrate UT --- paddle/fluid/primitive/codegen/gen.py | 2 + paddle/fluid/primitive/rule/vjp/details.h | 45 ++++ paddle/phi/infermeta/unary.cc | 10 - python/paddle/decomposition/decomp.py | 2 +- test/deprecated/legacy_test/CMakeLists.txt | 2 - .../legacy_test/test_squeeze2_op.py | 206 ---------------- test/legacy_test/CMakeLists.txt | 2 + test/legacy_test/test_squeeze2_op.py | 228 ++++++++++++++++++ .../legacy_test/test_unsqueeze2_op.py | 6 +- 9 files changed, 280 insertions(+), 223 deletions(-) mode change 100755 => 100644 test/deprecated/legacy_test/test_squeeze2_op.py create mode 100755 test/legacy_test/test_squeeze2_op.py rename test/{deprecated => }/legacy_test/test_unsqueeze2_op.py (98%) diff --git a/paddle/fluid/primitive/codegen/gen.py b/paddle/fluid/primitive/codegen/gen.py index 90970ba3c3119..60131dda70b10 100644 --- a/paddle/fluid/primitive/codegen/gen.py +++ b/paddle/fluid/primitive/codegen/gen.py @@ -101,8 +101,10 @@ 'scatter_grad', 'scatter_nd_add_grad', 'slice_grad', + 'squeeze_grad', 'tile_grad', 'topk_grad', + 'unsqueeze_grad', ] # whole vjp list of primitive op vjp diff --git a/paddle/fluid/primitive/rule/vjp/details.h b/paddle/fluid/primitive/rule/vjp/details.h index 3e577c225387c..59c031952ee7f 100644 --- a/paddle/fluid/primitive/rule/vjp/details.h +++ b/paddle/fluid/primitive/rule/vjp/details.h @@ -877,6 +877,51 @@ void softmax_grad(const Tensor& out, } } +template +void squeeze_grad(const Tensor& xshape, + const Tensor& out_grad, + const IntArray& axis, + Tensor* x_grad) { + if (x_grad) { + auto x_grad_out = unsqueeze(out_grad, axis); + set_output(x_grad_out, x_grad); + } +} + +template +void unsqueeze_grad(const Tensor& xshape, + const Tensor& out_grad, + const IntArray& axis, + Tensor* x_grad) { + // for xshape = [10, 2, 5], axis = [3, 1, 1], out_grad.shape = [10, 1, 1, 2, + // 5, 1], it outputs squeeze axis = [5, 2, 1] + const auto& IncreaseAxis = [](std::vector* axis_data, + int64_t pivot) { + for (size_t i = 0; i < axis_data->size(); ++i) { + if ((*axis_data)[i] >= pivot) (*axis_data)[i] += 1; + } + }; + const auto& GetRealAxis = [&](const IntArray& axis) -> decltype(auto) { + // for axis = [0, 3, 3], it outputs [0, 3, 3+1], because unsqueeze support + // duplicated axis. + std::vector output_axis; + const int64_t x_rank = xshape.dims().size() - 1; + const std::vector axis_data = axis.GetData(); + for (size_t i = 0; i < axis_data.size(); ++i) { + int64_t value = axis_data[i]; + if (value < 0) value += (x_rank + i + 1); + IncreaseAxis(&output_axis, value); + output_axis.push_back(value); + } + return output_axis; + }; + + if (x_grad) { + auto x_grad_out = squeeze(out_grad, GetRealAxis(axis)); + set_output(x_grad_out, x_grad); + } +} + template void matmul_grad(const Tensor& x, const Tensor& y, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 90df04d732204..042e2bc7af73e 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -4286,16 +4286,6 @@ void SqueezeInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config) { const auto& x_dims = x.dims(); - // Check input tensor dims (<6) Eigen limit. - PADDLE_ENFORCE_LE(x_dims.size(), - 6, - phi::errors::InvalidArgument( - "The dimensions of Input(X) " - "should be in the range of [1, 6] (Eigen limit)." - "But received X's dimensions = %d, X's shape = [%s].", - x_dims.size(), - x_dims)); - if (!config.is_runtime && axes.FromTensor()) { // compile time infershape, set all elements to -1. int output_size = static_cast(x.dims().size() - axes.GetData().size()); diff --git a/python/paddle/decomposition/decomp.py b/python/paddle/decomposition/decomp.py index dd88704a77ce2..d613be161b681 100644 --- a/python/paddle/decomposition/decomp.py +++ b/python/paddle/decomposition/decomp.py @@ -213,7 +213,7 @@ def decompose( blacklist (frozenset): The Operators that will be exclude when decomposed into primitives. whitelist (frozenset): Only the operators in whitelist will be decomposed into primitives. start_index (int): The start index of decomposed operator in global block, default 0; - end_index (int): The end index of decomposed operator in global block, default -1 means all ops will be composed. start_index and end_index follow the principle of left open and right closed, that is [start_index, end_index). + end_index (int): The end index of decomposed operator in global block, default -1 means all ops will be composed. start_index and end_index follow the principle of left closed and right open, that is [start_index, end_index). Returns: dst_vars (list): A list contains all vars which replace origin ones in src_vars. diff --git a/test/deprecated/legacy_test/CMakeLists.txt b/test/deprecated/legacy_test/CMakeLists.txt index 4968e979a137a..93a47e5ef93d4 100644 --- a/test/deprecated/legacy_test/CMakeLists.txt +++ b/test/deprecated/legacy_test/CMakeLists.txt @@ -772,11 +772,9 @@ set(TEST_CINN_OPS test_top_k_v2_op test_elementwise_mul_op test_gather_nd_op - test_squeeze2_op test_elementwise_pow_op test_transpose_op test_reshape_op - test_unsqueeze2_op test_meshgrid_op test_scale_op test_scatter_op diff --git a/test/deprecated/legacy_test/test_squeeze2_op.py b/test/deprecated/legacy_test/test_squeeze2_op.py old mode 100755 new mode 100644 index b462d639a6703..ed347eda7350b --- a/test/deprecated/legacy_test/test_squeeze2_op.py +++ b/test/deprecated/legacy_test/test_squeeze2_op.py @@ -15,182 +15,14 @@ import os import unittest -import numpy as np -from op_test import OpTest, convert_float_to_uint16 from test_attribute_var import UnittestBase import paddle -from paddle.base import core from paddle.base.framework import Program, program_guard -from paddle.pir_utils import test_with_pir_api paddle.enable_static() -# Correct: General. -class TestSqueezeOp(OpTest): - def setUp(self): - self.op_type = "squeeze2" - self.prim_op_type = "comp" - self.python_api = paddle.squeeze - self.public_python_api = paddle.squeeze - self.python_out_sig = [ - "Out" - ] # python out sig is customized output signature. - self.init_test_case() - self.init_dtype() - self.if_enable_cinn() - x = np.random.random(self.ori_shape).astype("float64") - xshape = np.random.random(self.ori_shape).astype("float64") - if hasattr(self, "dtype") and self.dtype == np.uint16: - x = convert_float_to_uint16(x.astype(np.float32)) - xshape = convert_float_to_uint16(xshape.astype(np.float32)) - self.inputs = {"X": x} - self.init_attrs() - self.outputs = { - "Out": self.inputs["X"].reshape(self.new_shape), - "XShape": xshape, - } - - def if_enable_cinn(self): - pass - - def test_check_output(self): - self.check_output( - no_check_set=['XShape'], - check_prim=True, - check_pir=True, - check_prim_pir=True, - ) - - def test_check_grad(self): - self.check_grad( - ["X"], - "Out", - check_prim=True, - check_pir=True, - check_prim_pir=True, - ) - - def init_dtype(self): - self.dtype = np.float64 - - def init_test_case(self): - self.ori_shape = (1, 3, 1, 40) - self.axes = (0, 2) - self.new_shape = (3, 40) - - def init_attrs(self): - self.attrs = {"axes": self.axes} - - -@unittest.skipIf( - not core.is_compiled_with_cuda() - or not core.is_bfloat16_supported(core.CUDAPlace(0)), - "core is not compiled with CUDA and do not support bfloat16", -) -class TestSqueezeOpBF16OP(TestSqueezeOp): - def init_dtype(self): - self.dtype = np.uint16 - - -# Correct: There is mins axis. -class TestSqueezeOp1(TestSqueezeOp): - def init_test_case(self): - self.ori_shape = (1, 20, 1, 5) - self.axes = (0, -2) - self.new_shape = (20, 5) - - -@unittest.skipIf( - not core.is_compiled_with_cuda() - or not core.is_bfloat16_supported(core.CUDAPlace(0)), - "core is not compiled with CUDA and do not support bfloat16", -) -class TestSqueezeOp1BF16Op(TestSqueezeOp): - def init_dtype(self): - self.dtype = np.uint16 - - -class TestSqueezeOp_ZeroDim1(TestSqueezeOp): - def init_test_case(self): - self.ori_shape = () - self.axes = (0,) - self.new_shape = () - - -class TestSqueezeOp_ZeroDim2(TestSqueezeOp): - def init_test_case(self): - self.ori_shape = (1, 1, 1) - self.axes = (0, 1, 2) - self.new_shape = () - - -# Correct: No axes input. -class TestSqueezeOp2(TestSqueezeOp): - def setUp(self): - self.op_type = "squeeze2" - self.prim_op_type = "comp" - self.python_api = paddle.squeeze - self.public_python_api = paddle.squeeze - self.python_out_sig = [ - "Out" - ] # python out sig is customized output signature. - self.init_test_case() - self.init_dtype() - self.if_enable_cinn() - x = np.random.random(self.ori_shape).astype("float64") - xshape = np.random.random(self.ori_shape).astype("float64") - if hasattr(self, "dtype") and self.dtype == np.uint16: - x = convert_float_to_uint16(x.astype(np.float32)) - xshape = convert_float_to_uint16(xshape.astype(np.float32)) - self.inputs = {"X": x} - self.init_attrs() - self.outputs = { - "Out": self.inputs["X"].reshape(self.new_shape), - "XShape": xshape, - } - - def if_enable_cinn(self): - pass - - def init_dtype(self): - self.dtype = np.float64 - - def init_test_case(self): - self.ori_shape = (1, 20, 1, 5) - self.axes = () - self.new_shape = (20, 5) - - -@unittest.skipIf( - not core.is_compiled_with_cuda() - or not core.is_bfloat16_supported(core.CUDAPlace(0)), - "core is not compiled with CUDA and do not support bfloat16", -) -class TestSqueezeOp2BF16Op(TestSqueezeOp): - def init_dtype(self): - self.dtype = np.uint16 - - -# Correct: Just part of axes be squeezed. -class TestSqueezeOp3(TestSqueezeOp): - def init_test_case(self): - self.ori_shape = (6, 1, 5, 1, 4, 1) - self.axes = (1, -1) - self.new_shape = (6, 5, 1, 4) - - -@unittest.skipIf( - not core.is_compiled_with_cuda() - or not core.is_bfloat16_supported(core.CUDAPlace(0)), - "core is not compiled with CUDA and do not support bfloat16", -) -class TestSqueezeOp3BF16Op(TestSqueezeOp): - def init_dtype(self): - self.dtype = np.uint16 - - class TestSqueeze2AxesTensor(UnittestBase): def init_info(self): self.shapes = [[2, 3, 4]] @@ -266,43 +98,5 @@ def test_static(self): self.assertEqual(infer_out.shape, (2, 3, 10)) -# test api -class TestSqueezeAPI(unittest.TestCase): - def setUp(self): - self.executed_api() - - def executed_api(self): - self.squeeze = paddle.squeeze - - def test_api(self): - paddle.disable_static() - input_data = np.random.random([3, 2, 1]).astype("float32") - x = paddle.to_tensor(input_data) - out = self.squeeze(x, axis=2) - out.backward() - - self.assertEqual(out.shape, [3, 2]) - - paddle.enable_static() - - @test_with_pir_api - def test_error(self): - def test_axes_type(): - with paddle.static.program_guard( - paddle.static.Program(), paddle.static.Program() - ): - x2 = paddle.static.data( - name="x2", shape=[2, 1, 25], dtype="int32" - ) - self.squeeze(x2, axis=2.1) - - self.assertRaises(TypeError, test_axes_type) - - -class TestSqueezeInplaceAPI(TestSqueezeAPI): - def executed_api(self): - self.squeeze = paddle.squeeze_ - - if __name__ == "__main__": unittest.main() diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index 767775eb17db4..845870d4b787d 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -943,6 +943,8 @@ set(TEST_CINN_OPS test_group_norm_op test_tile_op test_sum_op + test_squeeze2_op + test_unsqueeze2_op test_elementwise_min_op test_take_along_axis_op test_strided_slice_op diff --git a/test/legacy_test/test_squeeze2_op.py b/test/legacy_test/test_squeeze2_op.py new file mode 100755 index 0000000000000..a2f549d415fc6 --- /dev/null +++ b/test/legacy_test/test_squeeze2_op.py @@ -0,0 +1,228 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from op_test import OpTest, convert_float_to_uint16 + +import paddle +from paddle.base import core +from paddle.pir_utils import test_with_pir_api + +paddle.enable_static() + + +# Correct: General. +class TestSqueezeOp(OpTest): + def setUp(self): + self.op_type = "squeeze2" + self.prim_op_type = "prim" + self.python_api = paddle.squeeze + self.public_python_api = paddle.squeeze + self.python_out_sig = [ + "Out" + ] # python out sig is customized output signature. + self.init_test_case() + self.init_dtype() + self.if_enable_cinn() + x = np.random.random(self.ori_shape).astype("float64") + xshape = np.random.random(self.ori_shape).astype("float64") + if hasattr(self, "dtype") and self.dtype == np.uint16: + x = convert_float_to_uint16(x.astype(np.float32)) + xshape = convert_float_to_uint16(xshape.astype(np.float32)) + self.inputs = {"X": x} + self.init_attrs() + self.outputs = { + "Out": self.inputs["X"].reshape(self.new_shape), + "XShape": xshape, + } + + def if_enable_cinn(self): + pass + + def test_check_output(self): + self.check_output( + no_check_set=['XShape'], + check_pir=True, + check_prim_pir=True, + ) + + def test_check_grad(self): + self.check_grad( + ["X"], + "Out", + check_pir=True, + check_prim_pir=True, + ) + + def init_dtype(self): + self.dtype = np.float64 + + def init_test_case(self): + self.ori_shape = (1, 3, 1, 40) + self.axes = (0, 2) + self.new_shape = (3, 40) + + def init_attrs(self): + self.attrs = {"axes": self.axes} + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA and do not support bfloat16", +) +class TestSqueezeOpBF16OP(TestSqueezeOp): + def init_dtype(self): + self.dtype = np.uint16 + + +# Correct: There is mins axis. +class TestSqueezeOp1(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (1, 20, 1, 5) + self.axes = (0, -2) + self.new_shape = (20, 5) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA and do not support bfloat16", +) +class TestSqueezeOp1BF16Op(TestSqueezeOp): + def init_dtype(self): + self.dtype = np.uint16 + + +class TestSqueezeOp_ZeroDim1(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = () + self.axes = (0,) + self.new_shape = () + + +class TestSqueezeOp_ZeroDim2(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (1, 1, 1) + self.axes = (0, 1, 2) + self.new_shape = () + + +# Correct: No axes input. +class TestSqueezeOp2(TestSqueezeOp): + def setUp(self): + self.op_type = "squeeze2" + self.prim_op_type = "comp" + self.python_api = paddle.squeeze + self.public_python_api = paddle.squeeze + self.python_out_sig = [ + "Out" + ] # python out sig is customized output signature. + self.init_test_case() + self.init_dtype() + self.if_enable_cinn() + x = np.random.random(self.ori_shape).astype("float64") + xshape = np.random.random(self.ori_shape).astype("float64") + if hasattr(self, "dtype") and self.dtype == np.uint16: + x = convert_float_to_uint16(x.astype(np.float32)) + xshape = convert_float_to_uint16(xshape.astype(np.float32)) + self.inputs = {"X": x} + self.init_attrs() + self.outputs = { + "Out": self.inputs["X"].reshape(self.new_shape), + "XShape": xshape, + } + + def if_enable_cinn(self): + pass + + def init_dtype(self): + self.dtype = np.float64 + + def init_test_case(self): + self.ori_shape = (1, 20, 1, 5) + self.axes = () + self.new_shape = (20, 5) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA and do not support bfloat16", +) +class TestSqueezeOp2BF16Op(TestSqueezeOp): + def init_dtype(self): + self.dtype = np.uint16 + + +# Correct: Just part of axes be squeezed. +class TestSqueezeOp3(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (6, 1, 5, 1, 4, 1) + self.axes = (1, -1) + self.new_shape = (6, 5, 1, 4) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA and do not support bfloat16", +) +class TestSqueezeOp3BF16Op(TestSqueezeOp): + def init_dtype(self): + self.dtype = np.uint16 + + +# test api +class TestSqueezeAPI(unittest.TestCase): + def setUp(self): + self.executed_api() + + def executed_api(self): + self.squeeze = paddle.squeeze + + def test_api(self): + paddle.disable_static() + input_data = np.random.random([3, 2, 1]).astype("float32") + x = paddle.to_tensor(input_data) + out = self.squeeze(x, axis=2) + out.backward() + + self.assertEqual(out.shape, [3, 2]) + + paddle.enable_static() + + @test_with_pir_api + def test_error(self): + def test_axes_type(): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x2 = paddle.static.data( + name="x2", shape=[2, 1, 25], dtype="int32" + ) + self.squeeze(x2, axis=2.1) + + self.assertRaises(TypeError, test_axes_type) + + +class TestSqueezeInplaceAPI(TestSqueezeAPI): + def executed_api(self): + self.squeeze = paddle.squeeze_ + + +if __name__ == "__main__": + unittest.main() diff --git a/test/deprecated/legacy_test/test_unsqueeze2_op.py b/test/legacy_test/test_unsqueeze2_op.py similarity index 98% rename from test/deprecated/legacy_test/test_unsqueeze2_op.py rename to test/legacy_test/test_unsqueeze2_op.py index 65b7420c02d52..2a0c0ec1620ff 100755 --- a/test/deprecated/legacy_test/test_unsqueeze2_op.py +++ b/test/legacy_test/test_unsqueeze2_op.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -37,7 +37,7 @@ def setUp(self): "Out": self.inputs["X"].reshape(self.new_shape), "XShape": np.random.random(self.ori_shape).astype("float64"), } - self.prim_op_type = "comp" + self.prim_op_type = "prim" self.if_enable_cinn() def if_enable_cinn(self): @@ -46,7 +46,6 @@ def if_enable_cinn(self): def test_check_output(self): self.check_output( no_check_set=["XShape"], - check_prim=True, check_pir=True, check_prim_pir=True, ) @@ -55,7 +54,6 @@ def test_check_grad(self): self.check_grad( ["X"], "Out", - check_prim=True, check_pir=True, check_prim_pir=True, ) From 6d3263a592754b0c4cf6f6753ce2d7b2910511ac Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Wed, 15 May 2024 01:57:54 +0000 Subject: [PATCH 3/5] rename UT --- .../{test_squeeze2_op.py => test_squeeze2_op_rename.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/deprecated/legacy_test/{test_squeeze2_op.py => test_squeeze2_op_rename.py} (100%) diff --git a/test/deprecated/legacy_test/test_squeeze2_op.py b/test/deprecated/legacy_test/test_squeeze2_op_rename.py similarity index 100% rename from test/deprecated/legacy_test/test_squeeze2_op.py rename to test/deprecated/legacy_test/test_squeeze2_op_rename.py From a64fe98ca194be51c385eb95505b57a49190d4fd Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Wed, 15 May 2024 03:49:08 +0000 Subject: [PATCH 4/5] fix timeout --- test/deprecated/legacy_test/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/test/deprecated/legacy_test/CMakeLists.txt b/test/deprecated/legacy_test/CMakeLists.txt index 93a47e5ef93d4..0ff6d25736133 100644 --- a/test/deprecated/legacy_test/CMakeLists.txt +++ b/test/deprecated/legacy_test/CMakeLists.txt @@ -750,6 +750,7 @@ set_tests_properties(test_graph_send_uv_op PROPERTIES TIMEOUT 60) set_tests_properties(test_pretrained_model PROPERTIES TIMEOUT 120) set_tests_properties(test_model PROPERTIES TIMEOUT 300) set_tests_properties(test_pretrained_model PROPERTIES TIMEOUT 600) +set_tests_properties(test_squeeze2_op_rename PROPERTIES TIMEOUT 120) if(APPLE) set_tests_properties(test_callback_early_stop PROPERTIES TIMEOUT 300) From 7adae97ff63a4fc22fedf723251636fa9c1567be Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Wed, 15 May 2024 07:22:45 +0000 Subject: [PATCH 5/5] fix atol --- test/ir/pir/cinn/sub_graphs/test_sub_graph_59.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ir/pir/cinn/sub_graphs/test_sub_graph_59.py b/test/ir/pir/cinn/sub_graphs/test_sub_graph_59.py index 088def7f45007..f0cc7826e933d 100644 --- a/test/ir/pir/cinn/sub_graphs/test_sub_graph_59.py +++ b/test/ir/pir/cinn/sub_graphs/test_sub_graph_59.py @@ -114,7 +114,7 @@ def test_ast_prim_cinn(self): for st, cinn in zip( paddle.utils.flatten(st_out), paddle.utils.flatten(cinn_out) ): - np.testing.assert_allclose(st.numpy(), cinn.numpy(), atol=1e-8) + np.testing.assert_allclose(st.numpy(), cinn.numpy(), atol=1e-6) if __name__ == '__main__':