[0D-Tensor] CINN supports topk, sort, argsort, fix infershape #55510

Merged
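Not part of the diff: a minimal sketch of the user-facing 0-D behavior these infershape fixes target, assuming a Paddle build with 0-D Tensor support. The new tests added below exercise exactly this pattern.

import paddle

# A 0-D (scalar) tensor has shape (), not (1,).
x = paddle.to_tensor(2.0, dtype="float32")

# With a 0-D input, only axis 0 or -1 is accepted and every result stays 0-D.
out = paddle.sort(x, axis=-1)            # shape []
idx = paddle.argsort(x, axis=0)          # shape []
val, pos = paddle.topk(x, k=1, axis=-1)  # both shape []
print(out.shape, idx.shape, val.shape, pos.shape)  # [] [] [] []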
45 changes: 34 additions & 11 deletions paddle/cinn/hlir/op/contrib/sort.cc
@@ -326,8 +326,18 @@ std::vector<std::vector<int>> InferShapeForSort(
       break;
     }
   }
-  CHECK_GT(inputs_shape[0].size(), axis)
-      << "The input's dim should be greater than axis! ";
+  if (inputs_shape[0].empty()) {
+    // 0D Tensor
+    CHECK(axis == 0 || axis == -1)
+        << "Axis must be 0 or -1 if input tensor is 0-dim";
+  } else {
+    if (axis < 0) {
+      axis += inputs_shape[0].size();
+    }
+    CHECK_GT(inputs_shape[0].size(), axis)
+        << "The input's dim should be greater than axis! ";
+  }
+
   std::vector<std::vector<int>> res{inputs_shape[0]};
   return res;
 }
@@ -352,11 +362,17 @@ std::vector<std::vector<int>> InferShapeForArgSort(
       break;
     }
   }
-  if (axis < 0) {
-    axis += inputs_shape[0].size();
+  if (inputs_shape[0].empty()) {
+    // 0D Tensor
+    CHECK(axis == 0 || axis == -1)
+        << "Axis must be 0 or -1 if input tensor is 0-dim";
+  } else {
+    if (axis < 0) {
+      axis += inputs_shape[0].size();
+    }
+    CHECK_GT(inputs_shape[0].size(), axis)
+        << "The input's dim should be greater than axis! ";
   }
-  CHECK_GT(inputs_shape[0].size(), axis)
-      << "The input's dim should be greater than axis! ";
   std::vector<std::vector<int>> res{inputs_shape[0], inputs_shape[0]};
 
   return res;
@@ -381,12 +397,19 @@ std::vector<std::vector<int>> InferShapeForTopK(
   auto axis_it = attrs.find("axis");
   CHECK(axis_it != attrs.end()) << "The attr axis of topk does not exist.";
   int axis = absl::get<int>(axis_it->second);
-  if (axis < 0) {
-    axis += res[0].size();
+
+  if (inputs_shape[0].empty()) {
+    // 0D Tensor
+    CHECK(axis == 0 || axis == -1)
+        << "Axis must be 0 or -1 if input tensor is 0-dim";
+  } else {
+    if (axis < 0) {
+      axis += inputs_shape[0].size();
+    }
+    CHECK_GE(axis, 0);
+    CHECK_LT(axis, res[0].size());
+    res[0][axis] = std::min(res[0][axis], k);
   }
-  CHECK_GE(axis, 0);
-  CHECK_LT(axis, res[0].size());
-  res[0][axis] = std::min(res[0][axis], k);
   return {res[0], res[0]};
 }

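The three hunks above share one validation pattern; here is a plain-Python sketch of it for reference. It is illustrative only: the helper name infer_sorted_shape is invented here and is not the CINN infershape API, and for top_k the non-0-D branch additionally clamps the axis dimension to k, as in the last hunk.

def infer_sorted_shape(input_shape, axis):
    # 0-D input: only axis 0 or -1 makes sense; the output keeps shape ().
    if len(input_shape) == 0:
        assert axis in (0, -1), "Axis must be 0 or -1 if input tensor is 0-dim"
        return list(input_shape)
    # N-D input: normalize a negative axis, then bounds-check it.
    if axis < 0:
        axis += len(input_shape)
    assert 0 <= axis < len(input_shape), "The input's dim should be greater than axis!"
    return list(input_shape)

# e.g. infer_sorted_shape([], -1) -> [], infer_sorted_shape([4, 6], -1) -> [4, 6]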
147 changes: 147 additions & 0 deletions test/cinn/ops/test_zero_dim_tensor.py
@@ -788,6 +788,153 @@ def test_check_results(self):
         self.check_outputs_and_grads()
 
 
+@OpTestTool.skip_if(
+    not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
+)
+class TestArgsortOp(OpTest):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.init_input()
+
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, []).astype(self.dtype),
+        }
+        self.axis = -1
+        self.target_shape = ()
+
+    def build_paddle_program(self, target):
+        x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
+        out = paddle.argsort(x, axis=self.axis)
+
+        self.paddle_outputs = [out]
+
+    def build_cinn_program(self, target):
+        builder = NetBuilder("argsort_op")
+        x = builder.create_input(
+            cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
+        )
+        out = builder.argsort(x, self.axis, True)
+
+        prog = builder.build()
+        res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]], out)
+
+        self.cinn_outputs = np.array([res[0]]).astype("int64")
+        self.assertEqual(res[0].shape, self.target_shape)
+
+    def test_check_results(self):
+        self.check_outputs_and_grads()
+
+
+class TestArgsortOp2(TestArgsortOp):
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, []).astype(self.dtype),
+        }
+        self.axis = 0
+        self.target_shape = ()
+
+
+@OpTestTool.skip_if(
+    not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
+)
+class TestSortOp(OpTest):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.init_input()
+
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, []).astype(self.dtype),
+        }
+        self.axis = -1
+        self.target_shape = ()
+
+    def build_paddle_program(self, target):
+        x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
+        out = paddle.sort(x, axis=self.axis)
+
+        self.paddle_outputs = [out]
+
+    def build_cinn_program(self, target):
+        builder = NetBuilder("sort_op")
+        x = builder.create_input(
+            cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
+        )
+        out = builder.sort(x, self.axis, True)
+
+        prog = builder.build()
+        res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]], [out])
+
+        self.cinn_outputs = res
+        self.assertEqual(res[0].shape, self.target_shape)
+
+    def test_check_results(self):
+        self.check_outputs_and_grads()
+
+
+class TestSortOp2(TestSortOp):
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, []).astype(self.dtype),
+        }
+        self.axis = 0
+        self.target_shape = ()
+
+
+@OpTestTool.skip_if(
+    not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
+)
+class TestTopkOp(OpTest):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.init_input()
+
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, []).astype(self.dtype),
+        }
+        self.axis = -1
+        self.target_shape = ()
+
+    def build_paddle_program(self, target):
+        x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
+        out, indices = paddle.topk(x, k=1, axis=self.axis)
+
+        self.paddle_outputs = [out, indices]
+
+    def build_cinn_program(self, target):
+        builder = NetBuilder("topk_op")
+        x = builder.create_input(
+            cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
+        )
+        out = builder.top_k(x, 1, self.axis, True)
+
+        prog = builder.build()
+        res = self.get_cinn_output(
+            prog, target, [x], [self.inputs["x"]], [out[0], out[1]]
+        )
+
+        self.cinn_outputs = res
+        self.assertEqual(res[0].shape, self.target_shape)
+        self.assertEqual(res[1].shape, self.target_shape)
+
+    def test_check_results(self):
+        self.check_outputs_and_grads()
+
+
+class TestTopkOp2(TestTopkOp):
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, []).astype(self.dtype),
+        }
+        self.axis = 0
+        self.target_shape = ()
+
+
 @OpTestTool.skip_if(
     not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
 )