Skip to content

Commit

Permalink
[0D-Tensor] CINN supports squeeze, fix infershape and GetPositiveAxes (
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahy0825 authored and wz1qqx committed Jul 31, 2023
1 parent 8214bc6 commit 9fd0db1
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 5 deletions.
4 changes: 0 additions & 4 deletions paddle/cinn/hlir/op/elementwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -705,10 +705,6 @@ std::vector<std::vector<int>> InferShapeForSqueeze(

VLOG(4) << "The output calculated in Squeeze: "
<< cinn::utils::Join(output_shape, ", ");

if (output_shape.size() == 0) {
output_shape.push_back(1);
}
return {output_shape};
}

Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/utils/functional.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ std::vector<int> GetPositiveAxes(const std::vector<int>& axes, int rank) {
std::vector<int> new_axes(axes.size());
for (int i = 0; i < axes.size(); ++i) {
int axis = axes[i] + (axes[i] < 0 ? rank : 0);
CHECK(axis >= 0 && axis < rank)
CHECK(axis >= 0 && (rank == 0 || axis < rank))
<< "The axis should in [" << -rank << ", " << rank << "), but axes["
<< i << "]=" << axes[i] << " not.";
new_axes[i] = axis;
Expand Down
57 changes: 57 additions & 0 deletions test/cinn/ops/test_zero_dim_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,5 +713,62 @@ def test_check_results(self):
self.check_outputs_and_grads()


@OpTestTool.skip_if(
not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
class TestSqueezeOp(OpTest):
def setUp(self):
np.random.seed(2023)
self.dtype = "float32"
self.init_input()

def init_input(self):
self.inputs = {
"x": np.random.randint(-10, 10, []).astype(self.dtype),
}
self.squeeze_axex = [0]
self.target_shape = ()

def build_paddle_program(self, target):
x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
out = paddle.squeeze(x, axis=self.squeeze_axex)

self.paddle_outputs = [out]

def build_cinn_program(self, target):
builder = NetBuilder("squeeze_op")
x = builder.create_input(
cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
)
out = builder.squeeze(x, self.squeeze_axex)

prog = builder.build()
res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]], [out])

self.cinn_outputs = res
self.assertEqual(res[0].shape, self.target_shape)

def test_check_results(self):
self.check_outputs_and_grads()


class TestSqueezeOp1D(TestSqueezeOp):
def init_input(self):
self.inputs = {
"x": np.random.randint(-10, 10, [1]).astype(self.dtype),
}
self.squeeze_axex = []
self.target_shape = ()


class TestSqueezeOp2D(TestSqueezeOp):
def init_input(self):
self.inputs = {
"x": np.random.randint(-10, 10, [1, 1]).astype(self.dtype),
}
self.squeeze_axex = [0, 1]
self.target_shape = ()


if __name__ == "__main__":
unittest.main()

0 comments on commit 9fd0db1

Please sign in to comment.