diff --git a/paddle/cinn/hlir/op/elementwise.cc b/paddle/cinn/hlir/op/elementwise.cc index 0294fa6cf6f8a4..c225ea48118034 100644 --- a/paddle/cinn/hlir/op/elementwise.cc +++ b/paddle/cinn/hlir/op/elementwise.cc @@ -705,10 +705,6 @@ std::vector> InferShapeForSqueeze( VLOG(4) << "The output calculated in Squeeze: " << cinn::utils::Join(output_shape, ", "); - - if (output_shape.size() == 0) { - output_shape.push_back(1); - } return {output_shape}; } diff --git a/paddle/cinn/utils/functional.cc b/paddle/cinn/utils/functional.cc index 9fd5799bc6e875..a71d73d41f8373 100644 --- a/paddle/cinn/utils/functional.cc +++ b/paddle/cinn/utils/functional.cc @@ -23,7 +23,7 @@ std::vector GetPositiveAxes(const std::vector& axes, int rank) { std::vector new_axes(axes.size()); for (int i = 0; i < axes.size(); ++i) { int axis = axes[i] + (axes[i] < 0 ? rank : 0); - CHECK(axis >= 0 && axis < rank) + CHECK(axis >= 0 && (rank == 0 || axis < rank)) << "The axis should in [" << -rank << ", " << rank << "), but axes[" << i << "]=" << axes[i] << " not."; new_axes[i] = axis; diff --git a/test/cinn/ops/test_zero_dim_tensor.py b/test/cinn/ops/test_zero_dim_tensor.py index 019b7638f7ef7f..99b13247ea1f59 100644 --- a/test/cinn/ops/test_zero_dim_tensor.py +++ b/test/cinn/ops/test_zero_dim_tensor.py @@ -713,5 +713,62 @@ def test_check_results(self): self.check_outputs_and_grads() +@OpTestTool.skip_if( + not is_compiled_with_cuda(), "x86 test will be skipped due to timeout." +) +class TestSqueezeOp(OpTest): + def setUp(self): + np.random.seed(2023) + self.dtype = "float32" + self.init_input() + + def init_input(self): + self.inputs = { + "x": np.random.randint(-10, 10, []).astype(self.dtype), + } + self.squeeze_axex = [0] + self.target_shape = () + + def build_paddle_program(self, target): + x = paddle.to_tensor(self.inputs["x"], stop_gradient=False) + out = paddle.squeeze(x, axis=self.squeeze_axex) + + self.paddle_outputs = [out] + + def build_cinn_program(self, target): + builder = NetBuilder("squeeze_op") + x = builder.create_input( + cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x" + ) + out = builder.squeeze(x, self.squeeze_axex) + + prog = builder.build() + res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]], [out]) + + self.cinn_outputs = res + self.assertEqual(res[0].shape, self.target_shape) + + def test_check_results(self): + self.check_outputs_and_grads() + + +class TestSqueezeOp1D(TestSqueezeOp): + def init_input(self): + self.inputs = { + "x": np.random.randint(-10, 10, [1]).astype(self.dtype), + } + self.squeeze_axex = [] + self.target_shape = () + + +class TestSqueezeOp2D(TestSqueezeOp): + def init_input(self): + self.inputs = { + "x": np.random.randint(-10, 10, [1, 1]).astype(self.dtype), + } + self.squeeze_axex = [0, 1] + self.target_shape = () + + if __name__ == "__main__": unittest.main()