diff --git a/paddle/cinn/hlir/op/contrib/sort.cc b/paddle/cinn/hlir/op/contrib/sort.cc
index 770175683d0288..0fe527811f072e 100644
--- a/paddle/cinn/hlir/op/contrib/sort.cc
+++ b/paddle/cinn/hlir/op/contrib/sort.cc
@@ -326,8 +326,18 @@ std::vector<std::vector<int>> InferShapeForSort(
       break;
     }
   }
-  CHECK_GT(inputs_shape[0].size(), axis)
-      << "The input's dim should be greater than axis! ";
+  if (inputs_shape[0].empty()) {
+    // 0D Tensor
+    CHECK(axis == 0 || axis == -1)
+        << "Axis must be 0 or -1 if input tensor is 0-dim";
+  } else {
+    if (axis < 0) {
+      axis += inputs_shape[0].size();
+    }
+    CHECK_GT(inputs_shape[0].size(), axis)
+        << "The input's dim should be greater than axis! ";
+  }
+
   std::vector<std::vector<int>> res{inputs_shape[0]};
   return res;
 }
@@ -352,11 +362,17 @@ std::vector<std::vector<int>> InferShapeForArgSort(
       break;
     }
   }
-  if (axis < 0) {
-    axis += inputs_shape[0].size();
+  if (inputs_shape[0].empty()) {
+    // 0D Tensor
+    CHECK(axis == 0 || axis == -1)
+        << "Axis must be 0 or -1 if input tensor is 0-dim";
+  } else {
+    if (axis < 0) {
+      axis += inputs_shape[0].size();
+    }
+    CHECK_GT(inputs_shape[0].size(), axis)
+        << "The input's dim should be greater than axis! ";
   }
-  CHECK_GT(inputs_shape[0].size(), axis)
-      << "The input's dim should be greater than axis! ";
   std::vector<std::vector<int>> res{inputs_shape[0], inputs_shape[0]};
   return res;
 }
@@ -381,12 +397,19 @@ std::vector<std::vector<int>> InferShapeForTopK(
   auto axis_it = attrs.find("axis");
   CHECK(axis_it != attrs.end()) << "The attr axis of topk does not exist.";
   int axis = absl::get<int>(axis_it->second);
-  if (axis < 0) {
-    axis += res[0].size();
+
+  if (inputs_shape[0].empty()) {
+    // 0D Tensor
+    CHECK(axis == 0 || axis == -1)
+        << "Axis must be 0 or -1 if input tensor is 0-dim";
+  } else {
+    if (axis < 0) {
+      axis += inputs_shape[0].size();
+    }
+    CHECK_GE(axis, 0);
+    CHECK_LT(axis, res[0].size());
+    res[0][axis] = std::min(res[0][axis], k);
   }
-  CHECK_GE(axis, 0);
-  CHECK_LT(axis, res[0].size());
-  res[0][axis] = std::min(res[0][axis], k);
 
   return {res[0], res[0]};
 }
diff --git a/test/cinn/ops/test_zero_dim_tensor.py b/test/cinn/ops/test_zero_dim_tensor.py
index 3f73809cfcfc5d..1138820f79b3f3 100644
--- a/test/cinn/ops/test_zero_dim_tensor.py
+++ b/test/cinn/ops/test_zero_dim_tensor.py
@@ -788,6 +788,153 @@ def test_check_results(self):
         self.check_outputs_and_grads()
 
 
+@OpTestTool.skip_if(
+    not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
+)
+class TestArgsortOp(OpTest):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.init_input()
+
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, []).astype(self.dtype),
+        }
+        self.axis = -1
+        self.target_shape = ()
+
+    def build_paddle_program(self, target):
+        x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
+        out = paddle.argsort(x, axis=self.axis)
+
+        self.paddle_outputs = [out]
+
+    def build_cinn_program(self, target):
+        builder = NetBuilder("argsort_op")
+        x = builder.create_input(
+            cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
+        )
+        out = builder.argsort(x, self.axis, True)
+
+        prog = builder.build()
+        res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]], out)
+
+        self.cinn_outputs = np.array([res[0]]).astype("int64")
+        self.assertEqual(res[0].shape, self.target_shape)
+
+    def test_check_results(self):
+        self.check_outputs_and_grads()
+
+
+class TestArgsortOp2(TestArgsortOp):
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, []).astype(self.dtype),
+        }
+        self.axis = 0
+        self.target_shape = ()
+
+
+@OpTestTool.skip_if(
+    not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
+)
+class TestSortOp(OpTest):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.init_input()
+
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, []).astype(self.dtype),
+        }
+        self.axis = -1
+        self.target_shape = ()
+
+    def build_paddle_program(self, target):
+        x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
+        out = paddle.sort(x, axis=self.axis)
+
+        self.paddle_outputs = [out]
+
+    def build_cinn_program(self, target):
+        builder = NetBuilder("sort_op")
+        x = builder.create_input(
+            cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
+        )
+        out = builder.sort(x, self.axis, True)
+
+        prog = builder.build()
+        res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]], [out])
+
+        self.cinn_outputs = res
+        self.assertEqual(res[0].shape, self.target_shape)
+
+    def test_check_results(self):
+        self.check_outputs_and_grads()
+
+
+class TestSortOp2(TestSortOp):
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, []).astype(self.dtype),
+        }
+        self.axis = 0
+        self.target_shape = ()
+
+
+@OpTestTool.skip_if(
+    not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
+)
+class TestTopkOp(OpTest):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.init_input()
+
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, []).astype(self.dtype),
+        }
+        self.axis = -1
+        self.target_shape = ()
+
+    def build_paddle_program(self, target):
+        x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
+        out, indices = paddle.topk(x, k=1, axis=self.axis)
+
+        self.paddle_outputs = [out, indices]
+
+    def build_cinn_program(self, target):
+        builder = NetBuilder("topk_op")
+        x = builder.create_input(
+            cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
+        )
+        out = builder.top_k(x, 1, self.axis, True)
+
+        prog = builder.build()
+        res = self.get_cinn_output(
+            prog, target, [x], [self.inputs["x"]], [out[0], out[1]]
+        )
+
+        self.cinn_outputs = res
+        self.assertEqual(res[0].shape, self.target_shape)
+        self.assertEqual(res[1].shape, self.target_shape)
+
+    def test_check_results(self):
+        self.check_outputs_and_grads()
+
+
+class TestTopkOp2(TestTopkOp):
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, []).astype(self.dtype),
+        }
+        self.axis = 0
+        self.target_shape = ()
+
+
 @OpTestTool.skip_if(
     not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
 )