From ce8541520adc4474d9e8e644da397c282845f12b Mon Sep 17 00:00:00 2001 From: Tianyu Feng <45195157+fty1777@users.noreply.github.com> Date: Mon, 11 Mar 2024 10:51:45 +0800 Subject: [PATCH] Symbolic shape inference support for pd_op.split and builtin.split (#62394) * WIP: builtin.split op infer sym shape * bug fix * Update paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc Co-authored-by: Bo Zhang <105368690+zhangbopd@users.noreply.github.com> * Update paddle/fluid/pir/dialect/operator/ir/op_dialect.cc Co-authored-by: Bo Zhang <105368690+zhangbopd@users.noreply.github.com> * Update paddle/fluid/pir/dialect/operator/ir/op_dialect.cc Co-authored-by: Bo Zhang <105368690+zhangbopd@users.noreply.github.com> * pd_op.split followed by builtin.split * pd_op.split infer sym shape bugfix and unittest; fix op infer sym error outputs * recover SplitWithNumOpInferSymbolicShape Unimplemented exception raising * code refinement * Rewrite PADDLE_ENFORCE * remove incorrect comments * Rewrite PADDLE_ENFORCE * Rewrite PADDLE_ENFORCE --------- Co-authored-by: Bo Zhang <105368690+zhangbopd@users.noreply.github.com> --- .../paddle_op_infer_sym.cc | 94 ++++++++++++++++++- .../pir/dialect/operator/ir/op_dialect.cc | 31 ++++++ paddle/phi/api/yaml/legacy_ops.yaml | 1 + .../cinn/symbolic/test_op_infer_sym_shape.py | 81 +++++++++++++++- .../symbolic/test_unary_op_infer_sym_shape.py | 2 +- 5 files changed, 202 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc index ee4f2d406b3a29..0d9f6ce5a036cb 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc @@ -958,8 +958,98 @@ bool ExpandAsOpInferSymbolicShape( bool SplitOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); + // input + const auto &x_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + PADDLE_ENFORCE_EQ(x_shape_or_data.data().has_value(), + false, + phi::errors::InvalidArgument( + "InferSymbolicShape of SplitOp only support input with " + "value now.")); + const auto &x_dims_sym = x_shape_or_data.shape(); + + // axis + CHECK(op->operand_source(2).defining_op()->isa()); + + int64_t axis = op->operand_source(2) + .defining_op() + .attributes() + .at("value") + .dyn_cast() + .data() + .to(); + + // sections + const std::vector §ions_sym = [&] { + const auto §ions_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + std::vector sections_sym; + if (sections_shape_or_data.data().has_value()) { + sections_sym = sections_shape_or_data.data().value(); + } else { + sections_sym = sections_shape_or_data.shape(); + } + return sections_sym; + }(); + + // output + const symbol::TensorListShapeOrDataDimExprs &output_shape_data_list = [&] { + const auto &GetSum = [&](const auto &dim_exprs, const auto &Filter) { + symbol::DimExpr sum{0}; + for (const auto &dim_expr : dim_exprs) { + if (Filter(dim_expr)) { + sum = sum + dim_expr; + } + } + return sum; + }; + const auto &All = [&](const auto &dim_exprs, const auto &Cond) { + for (const auto &dim_expr : dim_exprs) { + if (!Cond(dim_expr)) { + return false; + } + } + return true; + }; + const auto &IsNotMinusOne = [&](const symbol::DimExpr &dim_expr) { + if (dim_expr.isa()) { + return dim_expr.dyn_cast() != static_cast(-1); + } + return true; + }; + const auto &sum_exclude_minus_one = GetSum(sections_sym, IsNotMinusOne); + + const bool &all_sections_sym_not_minus_one = + All(sections_sym, IsNotMinusOne); + if (all_sections_sym_not_minus_one) { + shape_analysis->CreateDimExprBuilder().CstrEq(x_dims_sym[axis], + sum_exclude_minus_one); + } + + symbol::TensorListShapeOrDataDimExprs shape_data_list; + std::vector output_dims_sym = x_dims_sym; + if (!all_sections_sym_not_minus_one && sections_sym.size() == 1) { + VLOG(3) << "[SplitOp]-1 is the only split section. The output shape is " + "identical to the input shape."; + shape_data_list.push_back( + symbol::TensorShapeOrDataDimExprs(output_dims_sym)); + return shape_data_list; + } + for (uint32_t idx = 0; idx < sections_sym.size(); idx++) { + const auto §ion_sym = sections_sym[idx]; + output_dims_sym[axis] = IsNotMinusOne(section_sym) + ? section_sym + : x_dims_sym[axis] - sum_exclude_minus_one; + + shape_data_list.push_back( + symbol::TensorShapeOrDataDimExprs(output_dims_sym)); + } + return shape_data_list; + }(); + + shape_analysis->SetShapeOrDataForValue( + op->result(0), symbol::ShapeOrDataDimExprs{output_shape_data_list}); + return true; } diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 7262589c7ad3ab..1364c1e1e0c777 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -159,6 +159,32 @@ struct ShadowOutputOpInferSymbolicShapeInterfaceModel : InferSymbolicShapeInterface::Concept(InferSymbolicShape) {} }; +struct SplitOpInferSymbolicShapeInterfaceModel + : public InferSymbolicShapeInterface::Concept { + static inline bool InferSymbolicShape( + pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { + const auto& shape_data_list = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)) + .dyn_cast(); + + for (uint32_t rst_idx = 0; rst_idx < op->num_results(); rst_idx++) { + PADDLE_ENFORCE_EQ( + shape_data_list[rst_idx].data().has_value(), + false, + paddle::platform::errors::InvalidArgument( + "Currently InferSymbolicShape of SplitOp only support " + "input without value.")); + shape_analysis->SetShapeOrDataForValue( + op->result(rst_idx), + symbol::ShapeOrDataDimExprs{shape_data_list[rst_idx]}); + } + return true; + } + + SplitOpInferSymbolicShapeInterfaceModel() + : InferSymbolicShapeInterface::Concept(InferSymbolicShape) {} +}; + struct YieldOpInferSymbolicShapeInterfaceModel : public InferSymbolicShapeInterface::Concept { static inline bool InferSymbolicShape( @@ -196,6 +222,11 @@ OperatorDialect::OperatorDialect(pir::IrContext* ctx) InferSymbolicShapeInterface, ShadowOutputOpInferSymbolicShapeInterfaceModel>())); + info = ctx->GetRegisteredOpInfo(pir::SplitOp::name()); + info.AttachInterface(std::move( + pir::InterfaceValue::Get())); + info = ctx->GetRegisteredOpInfo(pir::YieldOp::name()); info.AttachInterface(std::move( pir::InterfaceValue::Get