Skip to content

Commit

Permalink
[0D-Tensor] CINN supports unsqueeze, delete hack in Paddle's pass (Pa…
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahy0825 authored and cqulilujia committed Jul 24, 2023
1 parent 2eba819 commit ebb12cf
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 28 deletions.
4 changes: 2 additions & 2 deletions paddle/cinn/hlir/op/elementwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -755,8 +755,8 @@ std::shared_ptr<OpStrategy> StrategyForExpandDims(
std::vector<std::vector<int>> InferShapeForExpandDims(
const std::vector<std::vector<int>> &inputs_shape,
const framework::AttrMapType &attrs) {
CHECK(!inputs_shape.empty() && !inputs_shape[0].empty())
<< "The input's shape size is 0! Please check again.";
CHECK(!inputs_shape.empty())
<< "At least 1 input tensor for expand_dims operator.";

CHECK_EQ(inputs_shape.size(), 1U);
const std::vector<int> &axes =
Expand Down
26 changes: 0 additions & 26 deletions paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,32 +32,6 @@ void CinnZeroTensorTrickPass::ApplyImpl(ir::Graph* graph) const {
"assign_value",
"gaussian_random",
"set_value"};
// NOTE: Hack squeeze2 0D-Tensor input
// If squeeze2 inputs 0D-Tensor and axes, The 0D-Tensor's shape will convert
// to 1D-Tensor, which could lead error. We hack squeeze2's axes attribute to
// resolve this. Change 0D-Tensor input to 1D-Tensor input and then make
// axes->axes[: -1]
for (const ir::Node* n : graph->Nodes()) {
if (n->IsOp() && n->Op()->Type() == "unsqueeze2") {
if (n->Op()->HasAttr("axes")) {
auto axes =
PADDLE_GET_CONST(std::vector<int32_t>, n->Op()->GetAttr("axes"));
for (const ir::Node* var : n->inputs) {
if (var->Var() &&
var->Var()->GetType() == proto::VarType::LOD_TENSOR) {
std::vector<int64_t> shape = var->Var()->GetShape();
if (shape.empty()) {
axes.pop_back();
n->Op()->SetAttr("axes", axes);
VLOG(4) << "unsqueeze2 axes dims is full, fix dim -> dim[:-1] to "
"avoid 0D-Tensor input error";
}
}
}
}
}
}

// CINN ops in this white list support 0D-Tensor, wait-list = {"remainder"}
const std::unordered_set<std::string> white_op_list{"elementwise_add",
"elementwise_sub",
Expand Down
54 changes: 54 additions & 0 deletions test/cinn/ops/test_zero_dim_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,60 @@ def test_check_results(self):
self.check_outputs_and_grads()


@OpTestTool.skip_if(
not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
class TestExpandDimsOp(OpTest):
def setUp(self):
np.random.seed(2023)
self.dtype = "float32"
self.init_input()

def init_input(self):
self.inputs = {
"x": np.random.randint(-10, 10, []).astype(self.dtype),
}
self.unsqueeze_dim = [0]
self.target_shape = (1,)

def build_paddle_program(self, target):
x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
out = paddle.unsqueeze(x, self.unsqueeze_dim)

self.paddle_outputs = [out]

def build_cinn_program(self, target):
builder = NetBuilder("unsqueeze_op")
x = builder.create_input(
cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
)
out = builder.expand_dims(x, self.unsqueeze_dim)

prog = builder.build()
res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]], [out])

self.cinn_outputs = res
self.assertEqual(res[0].shape, self.target_shape)

def test_check_results(self):
self.check_outputs_and_grads()


@OpTestTool.skip_if(
not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
class TestExpandDimsOp2D(TestExpandDimsOp):
def init_input(self):
self.inputs = {
"x": np.random.randint(-10, 10, []).astype(self.dtype),
}
self.unsqueeze_dim = [0, 1]
self.target_shape = (
1,
1,
)


@OpTestTool.skip_if(
not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
Expand Down

0 comments on commit ebb12cf

Please sign in to comment.