Symbolic shape inference support for pd_op.split and builtin.split (PaddlePaddle#62394)

* WIP: builtin.split op infer sym shape

* bug fix

* Update paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc

Co-authored-by: Bo Zhang <105368690+zhangbopd@users.noreply.github.com>

* Update paddle/fluid/pir/dialect/operator/ir/op_dialect.cc

Co-authored-by: Bo Zhang <105368690+zhangbopd@users.noreply.github.com>

* Update paddle/fluid/pir/dialect/operator/ir/op_dialect.cc

Co-authored-by: Bo Zhang <105368690+zhangbopd@users.noreply.github.com>

* pd_op.split followed by builtin.split

* pd_op.split infer sym shape bugfix and unittest; fix op infer sym error outputs

* recover SplitWithNumOpInferSymbolicShape Unimplemented exception raising

* code refinement

* Rewrite PADDLE_ENFORCE

* remove incorrect comments

* Rewrite PADDLE_ENFORCE

* Rewrite PADDLE_ENFORCE

---------

Co-authored-by: Bo Zhang <105368690+zhangbopd@users.noreply.github.com>
2 people authored and hitywt committed Mar 11, 2024
1 parent 5089edc commit ce85415
Showing 5 changed files with 202 additions and 7 deletions.
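
For orientation, a minimal usage sketch (not part of the diff): it mirrors the SplitNet pattern from the unit test added in this commit, using InputSpec with None dims so the program must reason about symbolic extents. The class name SplitDemo is illustrative, and whether the symbolic shape pass actually runs depends on build and runtime flags not shown here, so treat this as a sketch of the op pattern the new inference rule covers, not as the definitive setup.

# Minimal sketch (not part of this commit): mirrors the SplitNet pattern from
# the new unit test; InputSpec with None dims forces dynamic shapes.
import numpy as np
import paddle
from paddle.static import InputSpec


class SplitDemo(paddle.nn.Layer):
    def forward(self, x):
        # Sections contain -1, so one output extent along axis=1 must be
        # derived symbolically (rendered as Add(S1, -3) in the test strings).
        return paddle.split(x, [1, 2, -1], axis=1)


net = paddle.jit.to_static(
    SplitDemo(),
    input_spec=[InputSpec(shape=[None, None, None], dtype='float32')],
)
net.eval()
outs = net(paddle.to_tensor(np.random.rand(4, 6, 5).astype('float32')))
print([tuple(o.shape) for o in outs])  # [(4, 1, 5), (4, 2, 5), (4, 3, 5)]
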
@@ -958,8 +958,98 @@ bool ExpandAsOpInferSymbolicShape(

bool SplitOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
PADDLE_THROW(phi::errors::Unimplemented(
op->name() + " 's InferSymbolicShape interface is NOT implemented now."));
// input
const auto &x_shape_or_data =
shape_analysis->GetShapeOrDataForValue(op->operand_source(0));
PADDLE_ENFORCE_EQ(x_shape_or_data.data().has_value(),
false,
phi::errors::InvalidArgument(
"InferSymbolicShape of SplitOp only support input with "
"value now."));
const auto &x_dims_sym = x_shape_or_data.shape();

// axis
CHECK(op->operand_source(2).defining_op()->isa<paddle::dialect::FullOp>());

int64_t axis = op->operand_source(2)
.defining_op<paddle::dialect::FullOp>()
.attributes()
.at("value")
.dyn_cast<paddle::dialect::ScalarAttribute>()
.data()
.to<int64_t>();

// sections
const std::vector<symbol::DimExpr> &sections_sym = [&] {
const auto &sections_shape_or_data =
shape_analysis->GetShapeOrDataForValue(op->operand_source(1));
std::vector<symbol::DimExpr> sections_sym;
if (sections_shape_or_data.data().has_value()) {
sections_sym = sections_shape_or_data.data().value();
} else {
sections_sym = sections_shape_or_data.shape();
}
return sections_sym;
}();

// output
const symbol::TensorListShapeOrDataDimExprs &output_shape_data_list = [&] {
const auto &GetSum = [&](const auto &dim_exprs, const auto &Filter) {
symbol::DimExpr sum{0};
for (const auto &dim_expr : dim_exprs) {
if (Filter(dim_expr)) {
sum = sum + dim_expr;
}
}
return sum;
};
const auto &All = [&](const auto &dim_exprs, const auto &Cond) {
for (const auto &dim_expr : dim_exprs) {
if (!Cond(dim_expr)) {
return false;
}
}
return true;
};
const auto &IsNotMinusOne = [&](const symbol::DimExpr &dim_expr) {
if (dim_expr.isa<int64_t>()) {
return dim_expr.dyn_cast<int64_t>() != static_cast<int64_t>(-1);
}
return true;
};
const auto &sum_exclude_minus_one = GetSum(sections_sym, IsNotMinusOne);

const bool &all_sections_sym_not_minus_one =
All(sections_sym, IsNotMinusOne);
if (all_sections_sym_not_minus_one) {
shape_analysis->CreateDimExprBuilder().CstrEq(x_dims_sym[axis],
sum_exclude_minus_one);
}

symbol::TensorListShapeOrDataDimExprs shape_data_list;
std::vector<symbol::DimExpr> output_dims_sym = x_dims_sym;
if (!all_sections_sym_not_minus_one && sections_sym.size() == 1) {
VLOG(3) << "[SplitOp]-1 is the only split section. The output shape is "
"identical to the input shape.";
shape_data_list.push_back(
symbol::TensorShapeOrDataDimExprs(output_dims_sym));
return shape_data_list;
}
for (uint32_t idx = 0; idx < sections_sym.size(); idx++) {
const auto &section_sym = sections_sym[idx];
output_dims_sym[axis] = IsNotMinusOne(section_sym)
? section_sym
: x_dims_sym[axis] - sum_exclude_minus_one;

shape_data_list.push_back(
symbol::TensorShapeOrDataDimExprs(output_dims_sym));
}
return shape_data_list;
}();

shape_analysis->SetShapeOrDataForValue(
op->result(0), symbol::ShapeOrDataDimExprs{output_shape_data_list});

return true;
}

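A hand-worked restatement of the -1 handling above (not part of the diff): with input shape [S0, S1, S2], sections [1, 2, -1], and axis=1, the sum over the non-negative sections is 3, so the -1 slot resolves to S1 - 3, which is the Add(S1, -3) string the new unit test expects. Below is a hypothetical Python sketch of that rule; resolve_sections is an invented helper, and sympy symbols merely stand in for symbol::DimExpr.

# Hypothetical sketch (not PaddlePaddle API): sympy symbols stand in for
# symbol::DimExpr to show how a -1 section is resolved against the axis extent.
import sympy


def resolve_sections(axis_extent, sections):
    # Sum the known (non -1) sections, mirroring GetSum with IsNotMinusOne.
    known = sum(s for s in sections if s != -1)
    # Each -1 entry becomes axis_extent minus that sum; other entries pass through.
    return [s if s != -1 else axis_extent - known for s in sections]


S1 = sympy.Symbol('S1')
print(resolve_sections(S1, [1, 2, -1]))  # [1, 2, S1 - 3]
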
31 changes: 31 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/op_dialect.cc
@@ -159,6 +159,32 @@ struct ShadowOutputOpInferSymbolicShapeInterfaceModel
: InferSymbolicShapeInterface::Concept(InferSymbolicShape) {}
};

struct SplitOpInferSymbolicShapeInterfaceModel
: public InferSymbolicShapeInterface::Concept {
static inline bool InferSymbolicShape(
pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) {
const auto& shape_data_list =
shape_analysis->GetShapeOrDataForValue(op->operand_source(0))
.dyn_cast<symbol::TensorListShapeOrDataDimExprs>();

for (uint32_t rst_idx = 0; rst_idx < op->num_results(); rst_idx++) {
PADDLE_ENFORCE_EQ(
shape_data_list[rst_idx].data().has_value(),
false,
paddle::platform::errors::InvalidArgument(
"Currently InferSymbolicShape of SplitOp only support "
"input without value."));
shape_analysis->SetShapeOrDataForValue(
op->result(rst_idx),
symbol::ShapeOrDataDimExprs{shape_data_list[rst_idx]});
}
return true;
}

SplitOpInferSymbolicShapeInterfaceModel()
: InferSymbolicShapeInterface::Concept(InferSymbolicShape) {}
};

struct YieldOpInferSymbolicShapeInterfaceModel
: public InferSymbolicShapeInterface::Concept {
static inline bool InferSymbolicShape(
@@ -196,6 +222,11 @@ OperatorDialect::OperatorDialect(pir::IrContext* ctx)
InferSymbolicShapeInterface,
ShadowOutputOpInferSymbolicShapeInterfaceModel>()));

info = ctx->GetRegisteredOpInfo(pir::SplitOp::name());
info.AttachInterface(std::move(
pir::InterfaceValue::Get<InferSymbolicShapeInterface,
SplitOpInferSymbolicShapeInterfaceModel>()));

info = ctx->GetRegisteredOpInfo(pir::YieldOp::name());
info.AttachInterface(std::move(
pir::InterfaceValue::Get<InferSymbolicShapeInterface,
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/legacy_ops.yaml
@@ -1099,6 +1099,7 @@
kernel :
func : split
backward : split_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : split_with_num
args : (Tensor x, int num, Scalar(int) axis)
81 changes: 77 additions & 4 deletions test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py
@@ -351,7 +351,7 @@ def test_eval_symbolic(self):
np.testing.assert_equal(
sym_shape_str_list[j].find(self.expected[i][j]),
0,
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}',
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[j]}) is not expected {(self.expected[i][j])}',
)

return True
@@ -403,7 +403,7 @@ def test_eval_symbolic(self):
np.testing.assert_equal(
sym_shape_str_list[j].find(self.expected[i][j]),
0,
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}',
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[j]}) is not expected {(self.expected[i][j])}',
)

return True
@@ -453,7 +453,7 @@ def test_eval_symbolic(self):
np.testing.assert_equal(
sym_shape_str_list[j].find(self.expected[i][j]),
0,
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}',
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[j]}) is not expected {(self.expected[i][j])}',
)

return True
@@ -512,11 +512,84 @@ def test_eval_symbolic(self):
np.testing.assert_equal(
sym_shape_str_list[j].find(self.expected[i][j]),
0,
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}',
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[j]}) is not expected {(self.expected[i][j])}',
)

return True


class SplitNet(paddle.nn.Layer):
def __init__(self):
super().__init__()

def forward(self, x):
out = paddle.split(x, [-1], axis=1)
out = paddle.split(x, [1, 2, -1], axis=1)
out = paddle.split(x, [1, -1], axis=1)
out = paddle.split(x, [1, 2, 3], axis=1)
out = paddle.split(x, [1, 2, x.shape[1]], axis=1)

out = x.split([-1], axis=1)
out = x.split([1, 2, -1], axis=1)
out = x.split([1, -1], axis=1)
out = x.split([1, 2, 3], axis=1)
out = x.split([1, 2, x.shape[1]], axis=1)

return out


class TestSplitOpInferSymbolicShape(TestBase):
def prepare_data(self):
self.cases = [np.random.rand(4, 6, 5)]

self.expected = [
[
'shape[S0, S1, S2], data[NULL]',
'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, Add(S1, -3), S2], data[NULL]',
'shape[S0, 1, S2], data[NULL], shape[S0, Add(S1, -1), S2], data[NULL]',
'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, 3, S2], data[NULL]',
'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, S1, S2], data[NULL]',
'shape[S0, S1, S2], data[NULL]',
'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, Add(S1, -3), S2], data[NULL]',
'shape[S0, 1, S2], data[NULL], shape[S0, Add(S1, -1), S2], data[NULL]',
'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, 3, S2], data[NULL]',
'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, S1, S2], data[NULL]',
]
]

def test_eval_symbolic(self):
net = SplitNet()

for i in range(len(self.cases)):
x = self.cases[i]
x_spec = InputSpec(
shape=[None for index in range(len(x.shape))], dtype='float32'
)

input_spec = [x_spec]
net = apply_to_static(net, False, input_spec)
net.eval()

# check the infer result
sym_shape_str_list = get_sym_shape_str_for_op(
net, input_spec, 'pd_op.split'
)
np.testing.assert_equal(
len(sym_shape_str_list), len(self.expected[i])
)
for j in range(len(sym_shape_str_list)):
np.testing.assert_equal(
sym_shape_str_list[j].find(self.expected[i][j]),
0,
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[j]}) is not expected {(self.expected[i][j])}',
)

# TODO(fty1777): Add builtin.split op infer symbolic shape test
# Not added because attribute `sym_shape_str` does not support multi-output op now.
# See also: paddle/fluid/pir/transforms/shape_optimization_pass.cc:144.

return True


if __name__ == '__main__':
unittest.main()
2 changes: 1 addition & 1 deletion test/ir/pir/cinn/symbolic/test_unary_op_infer_sym_shape.py
@@ -102,7 +102,7 @@ def test_eval_symbolic(self):
np.testing.assert_equal(
sym_shape_str_list[j].find(self.expected[i][j]),
0,
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}',
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[j]}) is not expected {(self.expected[i][j])}',
)

return True
