From b80ece97131961b56bdb034d1ae42f65a5810dda Mon Sep 17 00:00:00 2001 From: Bo Zhang <105368690+zhangbopd@users.noreply.github.com> Date: Tue, 9 Jan 2024 13:43:35 +0800 Subject: [PATCH] [PIR] Reify InferSymbolicShapeInterface (#60438) * Reify InferSymbolicShapeInterface --- paddle/cinn/hlir/framework/pir/utils.cc | 3 + paddle/fluid/inference/paddle_inference.map | 1 - paddle/fluid/inference/utils/CMakeLists.txt | 2 +- .../interface/infer_symbolic_shape.cc | 268 +++++++++--------- .../operator/interface/infer_symbolic_shape.h | 6 + .../pir/dialect/operator/ir/op_dialect.cc | 50 +++- .../pir/transforms/shape_optimization_pass.cc | 113 ++------ paddle/fluid/pybind/pir.cc | 4 +- paddle/phi/api/yaml/ops.yaml | 2 + .../pir/dialect/shape/ir/shape_attribute.cc | 30 ++ paddle/pir/dialect/shape/ir/shape_attribute.h | 37 +++ .../shape/ir/shape_attribute_storage.h | 70 +++++ paddle/pir/dialect/shape/ir/shape_dialect.cc | 32 +++ paddle/pir/dialect/shape/ir/shape_dialect.h | 4 + paddle/pir/dialect/shape/utils/dim_expr.h | 20 +- paddle/pir/dialect/shape/utils/shape_utils.h | 3 + 16 files changed, 411 insertions(+), 234 deletions(-) create mode 100644 paddle/pir/dialect/shape/ir/shape_attribute.cc create mode 100644 paddle/pir/dialect/shape/ir/shape_attribute.h create mode 100644 paddle/pir/dialect/shape/ir/shape_attribute_storage.h diff --git a/paddle/cinn/hlir/framework/pir/utils.cc b/paddle/cinn/hlir/framework/pir/utils.cc index 8833ac496e32c..a0a6f5f15614b 100644 --- a/paddle/cinn/hlir/framework/pir/utils.cc +++ b/paddle/cinn/hlir/framework/pir/utils.cc @@ -25,6 +25,7 @@ #include "paddle/phi/common/data_type.h" #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/core/builtin_type.h" +#include "paddle/pir/dialect/shape/ir/shape_attribute.h" namespace cinn { namespace hlir { @@ -146,6 +147,8 @@ utils::Attribute CompatibleInfo::ConvertAttribute( } else if (src_attr.isa()) { auto dtype = src_attr.dyn_cast().data(); dst_attr = phi::DataTypeToString(dtype); + } else if (src_attr.isa<::pir::shape::SymbolAttribute>()) { + auto dst_attr = src_attr.dyn_cast<::pir::shape::SymbolAttribute>().data(); } else if (src_attr.isa<::pir::ArrayAttribute>()) { auto attr_vec = src_attr.dyn_cast<::pir::ArrayAttribute>().AsVector(); if (attr_vec.size() > 0) { diff --git a/paddle/fluid/inference/paddle_inference.map b/paddle/fluid/inference/paddle_inference.map index 29f131be85e1a..01a989cc568bc 100644 --- a/paddle/fluid/inference/paddle_inference.map +++ b/paddle/fluid/inference/paddle_inference.map @@ -82,7 +82,6 @@ *Pass*; *profile*; *phi*; - *pir*; PD_*; *cinn*; local: diff --git a/paddle/fluid/inference/utils/CMakeLists.txt b/paddle/fluid/inference/utils/CMakeLists.txt index 3dbc06bfc11b7..976cb2dccc8c1 100644 --- a/paddle/fluid/inference/utils/CMakeLists.txt +++ b/paddle/fluid/inference/utils/CMakeLists.txt @@ -13,7 +13,7 @@ cc_library( DEPS proto_desc enforce common) cc_library(table_printer SRCS table_printer.cc) -paddle_test(test_table_printer SRCS table_printer_tester.cc) +paddle_test(test_table_printer SRCS table_printer_tester.cc DEPS pir) proto_library(shape_range_info_proto SRCS shape_range_info.proto) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc index 99ab424279a7d..cd324b5f05c69 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc @@ -13,9 +13,10 @@ // limitations under the License. 
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/builtin_type.h" -#include "paddle/pir/dialect/shape/ir/shape_op.h" +#include "paddle/pir/dialect/shape/ir/shape_attribute.h" namespace paddle::dialect { @@ -25,29 +26,20 @@ bool InferSymbolicShapeInterface::InferSymbolicShape( } } // namespace paddle::dialect -namespace paddle::dialect { - namespace { -bool InferSymbolicShapeAllEqualUnary( +bool SameOperandsAndResultShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { pir::Value operand_source = op->operand_source(0); - std::string operand_source_id = pir::GetValueId(&operand_source); - pir::OpResult res = op->result(0); - std::string res_id = pir::GetValueId(&res); - shape_analysis->value_id_to_shapeordata_[res_id] = - shape_analysis->value_id_to_shapeordata_[operand_source_id]; - return true; -} -bool InferSymbolicShapeAllEqualBinary( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - pir::Value operand_source = op->operand_source(0); - std::string operand_source_id = pir::GetValueId(&operand_source); + symbol::ShapeOrDataDimExprs operand_shape_or_data = + shape_analysis->value_to_shape_or_data_[operand_source]; + + op->set_attribute("symbolic_shape", + pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), + operand_shape_or_data)); pir::OpResult res = op->result(0); - std::string res_id = pir::GetValueId(&res); - shape_analysis->value_id_to_shapeordata_[res_id] = - shape_analysis->value_id_to_shapeordata_[operand_source_id]; + shape_analysis->value_to_shape_or_data_[res] = operand_shape_or_data; return true; } @@ -91,14 +83,46 @@ bool InferSymbolicShapeElementWiseBinary( } // namespace +namespace paddle::dialect { bool AbsOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeAllEqualUnary(op, shape_analysis); + return SameOperandsAndResultShape(op, shape_analysis); } bool Abs_OpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeAllEqualUnary(op, shape_analysis); + return SameOperandsAndResultShape(op, shape_analysis); +} + +bool DataOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + auto attributes = op->attributes(); + pir::Attribute attr = attributes["shape"]; + std::vector dims = + attr.dyn_cast().data().GetData(); + + std::vector sym_dims; + for (auto dim : dims) { + symbol::DimExpr dim_expr; + if (dim == -1) { + symbol::DimExpr symbolic_dim_expr(shape_analysis->GetNextSymName()); + dim_expr = symbolic_dim_expr; + } else { + symbol::DimExpr numeric_dim_expr(dim); + dim_expr = numeric_dim_expr; + } + sym_dims.push_back(dim_expr); + } + + symbol::ShapeOrDataDimExprs shape_data{sym_dims}; + op->set_attribute( + "symbolic_shape", + pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); + + pir::OpResult res = op->result(0); + shape_analysis->value_to_shape_or_data_[res] = shape_data; + + return true; } bool AddOpInferSymbolicShape(pir::Operation *op, @@ -113,61 +137,50 @@ bool Add_OpInferSymbolicShape(pir::Operation *op, bool CastOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeAllEqualUnary(op, shape_analysis); + return SameOperandsAndResultShape(op, shape_analysis); } bool Cast_OpInferSymbolicShape(pir::Operation 
*op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeAllEqualUnary(op, shape_analysis); + return SameOperandsAndResultShape(op, shape_analysis); } bool ExpOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeAllEqualUnary(op, shape_analysis); + return SameOperandsAndResultShape(op, shape_analysis); } bool Exp_OpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeAllEqualUnary(op, shape_analysis); + return SameOperandsAndResultShape(op, shape_analysis); } bool SubtractOpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeAllEqualBinary(op, shape_analysis); + return SameOperandsAndResultShape(op, shape_analysis); } bool Subtract_OpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeAllEqualBinary(op, shape_analysis); + return SameOperandsAndResultShape(op, shape_analysis); } bool ShapeOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { pir::Value operand_source = op->operand_source(0); - std::string operand_source_id = pir::GetValueId(&operand_source); pir::OpResult res = op->result(0); - std::string res_id = pir::GetValueId(&res); - std::vector dims = - common::vectorize(res.type().dyn_cast().dims()); + symbol::ShapeOrDataDimExprs operand_shape_or_data = + shape_analysis->value_to_shape_or_data_[operand_source]; - std::vector shapes; - for (int64_t dim : dims) { - symbol::DimExpr dim_expr; - if (dim == -1) { - symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName()); - dim_expr = res_dim_expr; - } else { - symbol::DimExpr res_dim_expr(dim); - dim_expr = res_dim_expr; - } - shapes.push_back(dim_expr); - } + symbol::ShapeOrDataDimExprs extend_shape_or_data = + symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData( + operand_shape_or_data); - symbol::ShapeOrDataDimExprs shape_data{ - shapes, - shape_analysis->value_id_to_shapeordata_[operand_source_id].shape()}; - shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; + shape_analysis->value_to_shape_or_data_[res] = extend_shape_or_data; + op->set_attribute("symbolic_shape", + pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), + extend_shape_or_data)); return true; } @@ -179,27 +192,53 @@ bool ShapeSrOpInferSymbolicShape( bool StackOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { pir::Value operand_source = op->operand_source(0); - std::string operand_source_id = pir::GetValueId(&operand_source); - pir::OpResult res = op->result(0); - std::string res_id = pir::GetValueId(&res); + symbol::ShapeOrDataDimExprs operand_shape_or_data = + shape_analysis->value_to_shape_or_data_[operand_source]; - symbol::ShapeOrDataDimExprs shape_data; - shape_data = shape_analysis->value_id_to_shapeordata_[operand_source_id]; - shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; + std::vector out_dims; + if (operand_shape_or_data.data().has_value()) { + out_dims = operand_shape_or_data.data().value(); + } + // else : pir::VectorType x = + // operand_source.type().dyn_cast(); + // TODO(zhangbopd): else branch is not implemented yet. 
+ + symbol::ShapeOrDataDimExprs shape_data{out_dims}; + if (operand_shape_or_data.data().has_value()) { + shape_data = + symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(shape_data); + } + + op->set_attribute( + "symbolic_shape", + pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); + pir::OpResult res = op->result(0); + shape_analysis->value_to_shape_or_data_[res] = shape_data; return true; } bool ReshapeOpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - pir::Value operand_source_1 = op->operand_source(1); - std::string operand_source_1_id = pir::GetValueId(&operand_source_1); - pir::OpResult res = op->result(0); - std::string res_id = pir::GetValueId(&res); + pir::Value operand_source_shape = op->operand_source(1); - symbol::ShapeOrDataDimExprs shape_data{ - *(shape_analysis->value_id_to_shapeordata_[operand_source_1_id].data())}; + symbol::ShapeOrDataDimExprs operand_shape_or_data = + shape_analysis->value_to_shape_or_data_[operand_source_shape]; - shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; + std::vector<symbol::DimExpr> out_dims; + if (operand_shape_or_data.data().has_value()) { + out_dims = operand_shape_or_data.data().value(); + } + + symbol::ShapeOrDataDimExprs shape_data{out_dims}; + op->set_attribute( + "symbolic_shape", + pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); + + pir::OpResult res0 = op->result(0); + pir::OpResult res1 = op->result(1); + shape_analysis->value_to_shape_or_data_[res0] = shape_data; + shape_analysis->value_to_shape_or_data_[res1] = + shape_analysis->value_to_shape_or_data_[operand_source_shape]; return true; } @@ -208,51 +247,33 @@ bool Reshape_OpInferSymbolicShape( return ReshapeOpInferSymbolicShape(op, shape_analysis); } -bool SliceOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - pir::Value operand_source = op->operand_source(0); - std::string operand_source_id = pir::GetValueId(&operand_source); - pir::OpResult res = op->result(0); - std::string res_id = pir::GetValueId(&res); - - std::vector<int64_t> dims = - common::vectorize(res.type().dyn_cast<pir::DenseTensorType>().dims()); +bool FullIntArrayOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + auto attributes = op->attributes(); + pir::Attribute attr = attributes["value"]; + const auto &vec = attr.dyn_cast<pir::ArrayAttribute>().AsVector(); - std::vector<symbol::DimExpr> shapes; - for (int64_t dim : dims) { - symbol::DimExpr dim_expr; - if (dim == -1) { - symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName()); - dim_expr = res_dim_expr; - } else { - symbol::DimExpr res_dim_expr(dim); - dim_expr = res_dim_expr; - } - shapes.push_back(dim_expr); + std::vector<symbol::DimExpr> data; + for (auto item : vec) { + int64_t i = item.dyn_cast<pir::Int64Attribute>().data(); + data.push_back(symbol::DimExpr(i)); } - auto operand_source_1 = op->operand_source(1); - std::string operand_source_1_id = pir::GetValueId(&operand_source_1); - auto starts_array = - (shape_analysis->value_id_to_shapeordata_[operand_source_1_id]).data(); - auto start = starts_array->at(0).Get<std::int64_t>(); - - auto operand_source_2 = op->operand_source(2); - std::string operand_source_2_id = pir::GetValueId(&operand_source_2); - auto ends_array = - (shape_analysis->value_id_to_shapeordata_[operand_source_2_id]).data(); - auto end = ends_array->at(0).Get<std::int64_t>(); + symbol::ShapeOrDataDimExprs shape_data = + symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(data); - std::vector<symbol::DimExpr> data; - auto source_data = -
(shape_analysis->value_id_to_shapeordata_[operand_source_id]).data(); + op->set_attribute( + "symbolic_shape", + pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); - for (int i = start; i < end; i++) { - data.emplace_back(source_data->at(i)); - } + pir::OpResult res = op->result(0); + shape_analysis->value_to_shape_or_data_[res] = shape_data; return true; +} - symbol::ShapeOrDataDimExprs shape_data{shapes, data}; - shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; +bool SliceOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + // TODO(zhangbopd): Not implemented yet. return true; } @@ -261,47 +282,34 @@ namespace cinn::dialect { bool SliceOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + // TODO(zhangbopd): Not implemented yet, different from the one in paddle + // dialect. pir::Value operand_source = op->operand_source(0); - std::string operand_source_id = pir::GetValueId(&operand_source); - pir::OpResult res = op->result(0); - std::string res_id = pir::GetValueId(&res); - - std::vector<int64_t> dims = - common::vectorize(res.type().dyn_cast<pir::DenseTensorType>().dims()); - - std::vector<symbol::DimExpr> shapes; - for (int64_t dim : dims) { - symbol::DimExpr dim_expr; - if (dim == -1) { - symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName()); - dim_expr = res_dim_expr; - } else { - symbol::DimExpr res_dim_expr(dim); - dim_expr = res_dim_expr; - } - shapes.push_back(dim_expr); - } - + symbol::ShapeOrDataDimExprs operand_shape_or_data = + shape_analysis->value_to_shape_or_data_[operand_source]; pir::AttributeMap attributes = op->attributes(); - auto attr_starts = + std::vector<pir::Attribute> attr_starts = attributes["starts"].dyn_cast<pir::ArrayAttribute>().AsVector(); - auto start = attr_starts[0].dyn_cast<pir::Int64Attribute>().data(); - auto attr_ends = - attributes["ends"].dyn_cast<pir::ArrayAttribute>().AsVector(); - auto end = attr_ends[0].dyn_cast<pir::Int64Attribute>().data(); + int64_t start = attr_starts[0].dyn_cast<pir::Int64Attribute>().data(); - std::vector<symbol::DimExpr> data; - auto source_data = - (shape_analysis->value_id_to_shapeordata_[operand_source_id]).data(); + std::vector<symbol::DimExpr> out_dims; + if (operand_shape_or_data.data().has_value()) { + out_dims.push_back(operand_shape_or_data.data().value()[start]); + } - for (int i = start; i < end; i++) { - data.emplace_back(source_data->at(i)); + symbol::ShapeOrDataDimExprs shape_data{out_dims}; + if (operand_shape_or_data.data().has_value()) { + shape_data = + symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(shape_data); } + op->set_attribute( + "symbolic_shape", + pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); - symbol::ShapeOrDataDimExprs shape_data{shapes, data}; - shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; + pir::OpResult res = op->result(0); + shape_analysis->value_to_shape_or_data_[res] = shape_data; return true; } diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h index 4e7c70fd386ef..ff1891e49c3c6 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h @@ -68,6 +68,9 @@ bool AbsOpInferSymbolicShape(pir::Operation *op, bool Abs_OpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); +bool DataOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis); + bool AddOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); @@ -107,6
+110,9 @@ bool ReshapeOpInferSymbolicShape( bool Reshape_OpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); +bool FullIntArrayOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); + bool SliceOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 969edf32204bf..6e54933f83d11 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -26,6 +26,7 @@ #include "paddle/pir/core/utils.h" #include "paddle/pir/dialect/control_flow/ir/cf_dialect.h" #include "paddle/pir/dialect/control_flow/ir/cf_op.h" +#include "paddle/pir/dialect/shape/ir/shape_attribute.h" namespace paddle { namespace dialect { @@ -44,25 +45,44 @@ struct CombineOpInferSymbolicShapeInterfaceModel : public InferSymbolicShapeInterface::Concept { static inline bool InferSymbolicShape( pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { - std::vector shapes; - std::vector data; - - for (auto operand_source : op->operands_source()) { - std::string operand_source_id = pir::GetValueId(&operand_source); - auto source_data_p = - shape_analysis->value_id_to_shapeordata_[operand_source_id].data(); - auto source_shape_vec = - source_data_p.value_or(std::vector{}); - for (size_t i = 0; i < source_shape_vec.size(); i++) { - data.emplace_back(source_shape_vec.at(i)); + std::vector out_dims; + + // Currently for all operand : type.dims == 1u + for (size_t i = 0; i < op->num_operands(); ++i) { + auto type = + op->operand(i).type().dyn_cast(); + IR_ENFORCE(type, "Currently only support DenseTensorType."); + IR_ENFORCE(type.dims().size() == 0u, + "Currently CombineOp only support 0-d DenseTensorType for " + "InferSymbolicShape. But the dims of the %d-th " + "DenseTensorType is %d.", + i, + type.dims().size()); + } + + auto operand_source_1st_data = + shape_analysis->value_to_shape_or_data_[op->operand_source(0)].data(); + if (operand_source_1st_data.has_value()) { + for (auto operand_source : op->operands_source()) { + auto source_data = + shape_analysis->value_to_shape_or_data_[operand_source] + .data() + .value(); + out_dims.push_back(source_data[0]); } } - auto res = op->result(0); - auto res_id = pir::GetValueId(&res); + symbol::ShapeOrDataDimExprs shape_data{out_dims}; + if (operand_source_1st_data.has_value()) { + shape_data = + symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(shape_data); + } - symbol::ShapeOrDataDimExprs shape_data{shapes, data}; - shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; + op->set_attribute("symbolic_shape", + pir::shape::SymbolAttribute::get( + pir::IrContext::Instance(), shape_data)); + auto res = op->result(0); + shape_analysis->value_to_shape_or_data_[res] = shape_data; return true; } diff --git a/paddle/fluid/pir/transforms/shape_optimization_pass.cc b/paddle/fluid/pir/transforms/shape_optimization_pass.cc index b42584fcf0953..2bf33603fa7be 100644 --- a/paddle/fluid/pir/transforms/shape_optimization_pass.cc +++ b/paddle/fluid/pir/transforms/shape_optimization_pass.cc @@ -125,20 +125,6 @@ struct DimOfShapedTypeOpInterfacePattern : public OpRewritePattern { } }; -bool MaterializeShapeComputation(pir::ModuleOp m) { - // if (!InsertTieShapeOnRegion(&(m->region(0)))) return false; - // TODO(zhangbopd): add rewitter pattern for reifyInferShape. 
- RewritePatternSet patterns(m.ir_context()); - - patterns.Add>( - patterns.ir_context()); - - IR_ENFORCE(ApplyPatternsGreedily(m, std::move(patterns)).first, - "fail to materialize shape computation\n"); - return true; -} - using PassPipelineRunner = std::function; @@ -355,41 +341,38 @@ void PrintProgram(pir::ModuleOp m, std::string mgs) { void DebugPrintOpInfo( pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis = nullptr) { - VLOG(3) << op->name() << ", num_operands: " << op->num_operands(); for (auto& res : op->results()) { - auto value_id = pir::GetValueId(&res); std::ostringstream print_stream; - print_stream << ">>>> result(" << res.index() << ") 's ID: " << value_id; + print_stream << "result(" << res.index() << ") " + << "ShapeOrData: "; + if (shape_analysis != nullptr) { - auto shape_data = shape_analysis->value_id_to_shapeordata_[value_id]; - print_stream << ", ShapeOrData.shape: ["; - - // for (auto str : shape_data.shape()) { - // int64_t* i = std::get_if(&str); - // std::string* s = std::get_if(&str); - // if (i) { - // print_stream << *i << ", "; - // } else if (s) { - // print_stream << *s << ", "; - // } - // } - for (auto dim : shape_data.shape()) { - print_stream << dim << ", "; + auto shape_data = shape_analysis->value_to_shape_or_data_[res]; + print_stream << "shape: ["; + + for (size_t i = 0; i < shape_data.shape().size(); ++i) { + if (i != shape_data.shape().size() - 1) { + print_stream << symbol::ToString(shape_data.shape()[i]) << ","; + } else { + print_stream << symbol::ToString(shape_data.shape()[i]); + } } - print_stream << "], ShapeOrData.data: ["; + print_stream << "], data: ["; if (shape_data.data().has_value()) { - for (auto str : shape_data.data().value()) { - int64_t* i = std::get_if(&str); - std::string* s = std::get_if(&str); - if (i) { - print_stream << *i << ", "; - } else if (s) { - print_stream << *s << ", "; + for (size_t i = 0; i < shape_data.data().value().size(); ++i) { + if (i != shape_data.data().value().size() - 1) { + print_stream << symbol::ToString(shape_data.data().value()[i]) + << ","; + } else { + print_stream << symbol::ToString(shape_data.data().value()[i]); } } + } else { + print_stream << "nullopt"; } + print_stream << "]\n"; } VLOG(3) << print_stream.str(); @@ -403,51 +386,13 @@ void InferSymExprForAllValues(ModuleOp module_op) { for (uint32_t i = 0; i < module_op->num_regions(); i++) { for (auto& block : module_op->region(i)) { for (auto& op : block) { - if (op.num_operands() == 0) { - for (auto& res : op.results()) { - auto value_id = pir::GetValueId(&res); - - std::vector dims = common::vectorize( - res.type().dyn_cast().dims()); - - std::vector shapes; - for (int64_t dim : dims) { - symbol::DimExpr dim_expr; - if (dim == -1) { - symbol::DimExpr res_dim_expr(shape_analysis.GetNextSymName()); - dim_expr = res_dim_expr; - } else { - symbol::DimExpr res_dim_expr(dim); - dim_expr = res_dim_expr; - } - shapes.push_back(dim_expr); - } - - symbol::ShapeOrDataDimExprs shape_data{shapes}; - shape_analysis.value_id_to_shapeordata_[value_id] = shape_data; - - if (op.name() == "pd_op.full_int_array") { - std::vector data; - auto attributes = op.attributes(); - auto attr = attributes["value"]; - auto arr = attr.dyn_cast(); - const auto& vec = arr.AsVector(); - for (auto item : vec) { - int64_t i = item.dyn_cast().data(); - data.push_back(symbol::DimExpr(i)); - } - shape_analysis.value_id_to_shapeordata_[value_id].SetData(data); - } - } - } else { - auto infer_symbolic_shape_interface = - op.dyn_cast(); - if 
(infer_symbolic_shape_interface) { - PADDLE_ENFORCE(infer_symbolic_shape_interface.InferSymbolicShape( - &shape_analysis)); - } + auto infer_symbolic_shape_interface = + op.dyn_cast(); + if (infer_symbolic_shape_interface) { + VLOG(3) << op.name() << " has InferSymbolicShapeInterface."; + PADDLE_ENFORCE(infer_symbolic_shape_interface.InferSymbolicShape( + &shape_analysis)); } - DebugPrintOpInfo(&op, &shape_analysis); } } @@ -470,9 +415,9 @@ class ShapeOptimizationPass : public pir::Pass { PassPipelineRunner runner = [this](pir::PassManager& pm, pir::ModuleOp m) { return pm.Run(m.program()); }; - VLOG(3) << "===================== ShapeOptimizationPass Run End. " "============================="; + PrintProgram(module_op, "ShapeOptimizationPass Program"); } bool CanApplyOn(pir::Operation* op) const override { diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index 534801ada8080..d9d7eb3abe186 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -71,6 +71,7 @@ #include "paddle/pir/core/type.h" #include "paddle/pir/core/value.h" #include "paddle/pir/dialect/control_flow/ir/cf_dialect.h" +#include "paddle/pir/dialect/shape/ir/shape_dialect.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_manager.h" #include "paddle/pir/pass/pass_registry.h" @@ -85,7 +86,6 @@ #include "paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h" #include "paddle/cinn/hlir/framework/pir_compiler.h" #include "paddle/fluid/pir/transforms/build_cinn_pass.h" -#include "paddle/pir/dialect/shape/ir/shape_dialect.h" #endif namespace py = pybind11; @@ -1611,6 +1611,8 @@ void InferSymbolicShapePass( std::shared_ptr &pass_manager, // NOLINT Program &program) { // NOLINT if (FLAGS_pir_apply_shape_optimization_pass) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); pass_manager->AddPass(pir::CreateShapeOptimizationPass()); } } diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 835046c1e7911..6e7d07416f5fc 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -664,6 +664,7 @@ param : [name, shape, dtype] data_type : dtype backend : place + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : depthwise_conv2d args : (Tensor input, Tensor filter, int[] strides={1, 1}, int[] paddings={0, 0}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1}, str data_format="NCHW") @@ -1041,6 +1042,7 @@ param : [value, dtype] data_type : dtype backend : place + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : gammaln args : (Tensor x) diff --git a/paddle/pir/dialect/shape/ir/shape_attribute.cc b/paddle/pir/dialect/shape/ir/shape_attribute.cc new file mode 100644 index 0000000000000..c8751f0433ee1 --- /dev/null +++ b/paddle/pir/dialect/shape/ir/shape_attribute.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/pir/dialect/shape/ir/shape_attribute.h" + +namespace pir::shape { + +symbol::ShapeOrDataDimExprs SymbolAttribute::data() const { + return storage()->data(); +} + +SymbolAttribute SymbolAttribute::get(pir::IrContext* ctx, + const symbol::ShapeOrDataDimExprs& value) { + return AttributeManager::get(ctx, value); +} + +} // namespace pir::shape + +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::SymbolAttribute) diff --git a/paddle/pir/dialect/shape/ir/shape_attribute.h b/paddle/pir/dialect/shape/ir/shape_attribute.h new file mode 100644 index 0000000000000..1eda1ab35f1a7 --- /dev/null +++ b/paddle/pir/dialect/shape/ir/shape_attribute.h @@ -0,0 +1,37 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/core/attribute.h" +#include "paddle/pir/core/utils.h" +#include "paddle/pir/dialect/shape/ir/shape_attribute_storage.h" + +namespace pir::shape { + +class IR_API SymbolAttribute : public Attribute { + public: + using Attribute::Attribute; + + DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(SymbolAttribute, SymbolAttributeStorage); + + symbol::ShapeOrDataDimExprs data() const; + + static SymbolAttribute get(IrContext* ctx, + const symbol::ShapeOrDataDimExprs& value); +}; + +} // namespace pir::shape + +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::SymbolAttribute) diff --git a/paddle/pir/dialect/shape/ir/shape_attribute_storage.h b/paddle/pir/dialect/shape/ir/shape_attribute_storage.h new file mode 100644 index 0000000000000..11333f6b0d3e2 --- /dev/null +++ b/paddle/pir/dialect/shape/ir/shape_attribute_storage.h @@ -0,0 +1,70 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "paddle/common/enforce.h" +#include "paddle/pir/core/attribute.h" +#include "paddle/pir/core/attribute_base.h" +#include "paddle/pir/core/type.h" +#include "paddle/pir/core/utils.h" +#include "paddle/pir/dialect/shape/utils/dim_expr.h" + +namespace pir::shape { + +/// +/// \brief Define Parametric AttributeStorage for SymbolAttribute. 
+/// +struct SymbolAttributeStorage : public AttributeStorage { + using ParamKey = symbol::ShapeOrDataDimExprs; + + explicit SymbolAttributeStorage(const ParamKey &key) : data_(key) {} + + static SymbolAttributeStorage *Construct(const ParamKey &key) { + return new SymbolAttributeStorage(key); + } + + static std::size_t HashValue(const ParamKey &key) { + std::size_t hash_value = 0; + for (size_t i = 0; i < key.shape().size(); ++i) { + hash_value = hash_combine( + hash_value, + std::hash<std::string>()(symbol::ToString(key.shape()[i]))); + } + if (key.data().has_value()) { + for (size_t i = 0; i < key.data().value().size(); ++i) { + hash_value = hash_combine( + hash_value, + std::hash<std::string>()(symbol::ToString(key.data().value()[i]))); + } + } + + return hash_value; + } + + bool operator==(const ParamKey &key) const { + return data_.shape() == key.shape() && data_.data() == key.data(); + } + + ParamKey data() const { return data_; } + + private: + ParamKey data_; +}; + +} // namespace pir::shape diff --git a/paddle/pir/dialect/shape/ir/shape_dialect.cc b/paddle/pir/dialect/shape/ir/shape_dialect.cc index 0353a7610d2b3..083b0d2bd37c0 100644 --- a/paddle/pir/dialect/shape/ir/shape_dialect.cc +++ b/paddle/pir/dialect/shape/ir/shape_dialect.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/pir/dialect/shape/ir/shape_dialect.h" +#include "paddle/pir/dialect/shape/ir/shape_attribute.h" #include "paddle/pir/dialect/shape/ir/shape_op.h" namespace pir::shape { @@ -33,6 +34,37 @@ void ShapeDialect::initialize() { ExtractOp, ConstantOp, IndexCastOp>(); + + RegisterAttributes<SymbolAttribute>(); +} + +void ShapeDialect::PrintAttribute(pir::Attribute attr, std::ostream &os) const { + if (attr.isa<SymbolAttribute>()) { + SymbolAttribute symbol_attr = attr.dyn_cast<SymbolAttribute>(); + os << "(shape_data)"; + os << "["; + for (size_t i = 0; i < symbol_attr.data().shape().size(); ++i) { + if (i != symbol_attr.data().shape().size() - 1) { + os << symbol::ToString(symbol_attr.data().shape()[i]) << ","; + } else { + os << symbol::ToString(symbol_attr.data().shape()[i]); + } + } + os << "]_["; + if (symbol_attr.data().data().has_value()) { + for (size_t i = 0; i < symbol_attr.data().data().value().size(); ++i) { + if (i != symbol_attr.data().data().value().size() - 1) { + os << symbol::ToString(symbol_attr.data().data().value()[i]) << ","; + } else { + os << symbol::ToString(symbol_attr.data().data().value()[i]); + } + } + } else { + os << "nullopt"; + } + + os << "]"; + } } void ShapeDialect::PrintOperation(Operation *op, IrPrinter &printer) const { diff --git a/paddle/pir/dialect/shape/ir/shape_dialect.h b/paddle/pir/dialect/shape/ir/shape_dialect.h index 4be71aa0127ce..33b7419c251dd 100644 --- a/paddle/pir/dialect/shape/ir/shape_dialect.h +++ b/paddle/pir/dialect/shape/ir/shape_dialect.h @@ -23,7 +23,11 @@ namespace pir::shape { class IR_API ShapeDialect : public Dialect { public: explicit ShapeDialect(IrContext* context); + static const char* name() { return "shape"; } + + void PrintAttribute(pir::Attribute type, std::ostream& os) const override; + void PrintOperation(Operation* op, IrPrinter& printer) const override; // NOLINT diff --git a/paddle/pir/dialect/shape/utils/dim_expr.h b/paddle/pir/dialect/shape/utils/dim_expr.h index 4363d50769170..4cacc236ef99d 100644 --- a/paddle/pir/dialect/shape/utils/dim_expr.h +++ b/paddle/pir/dialect/shape/utils/dim_expr.h @@ -22,9 +22,9 @@ #include #include -#include "paddle/pir/core/dll_decl.h" - #include "glog/logging.h" +#include "paddle/common/enforce.h" +#include "paddle/pir/core/dll_decl.h" namespace symbol {
@@ -236,6 +236,22 @@ class ShapeOrData { return ShapeOrData(std::vector<T>{shape}, data); } + static ShapeOrData MakeConsistentShapeOrData( + const ShapeOrData& shape_or_data) { + IR_ENFORCE(shape_or_data.data() == std::nullopt, + "Data of ShapeOrData should be nullopt"); + T shape(std::int64_t(shape_or_data.shape().size())); + return ShapeOrData(std::vector<T>{shape}, shape_or_data.shape()); + } + + int64_t size() const { + if (data_.has_value()) { + return data_.value().size(); + } else { + return shape_.size(); + } + } + // Tensor's real shape const std::vector<T>& shape() const { return shape_; } // Specific for Tensor generated by shape-relevant ops diff --git a/paddle/pir/dialect/shape/utils/shape_utils.h b/paddle/pir/dialect/shape/utils/shape_utils.h index 8f383f3ad6e05..09a2aba1d15f2 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.h +++ b/paddle/pir/dialect/shape/utils/shape_utils.h @@ -92,6 +92,9 @@ class IR_API ShapeConstraintIRAnalysis : public ShapeAnalysis { std::unordered_map<std::string, symbol::ShapeOrDataDimExprs> value_id_to_shapeordata_; + std::unordered_map<pir::Value, symbol::ShapeOrDataDimExprs> + value_to_shape_or_data_; + private: // The operation this analysis runs on. ModuleOp m_;
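
Note for readers of this patch: the per-op pattern reified here is uniform. Each XxxOpInferSymbolicShape hook reads the operands' symbol::ShapeOrDataDimExprs from ShapeConstraintIRAnalysis::value_to_shape_or_data_, attaches the result to the op as a pir::shape::SymbolAttribute under the "symbolic_shape" key, and records the entry for the op's result value. Below is a minimal sketch for a hypothetical unary FooOp whose output shape equals its input shape; the op name is an assumption for illustration, and the calls simply mirror SameOperandsAndResultShape added above.

// Illustrative only -- not part of this patch. "FooOp" is a hypothetical
// unary op whose output shape equals its input shape.
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h"
#include "paddle/pir/dialect/shape/ir/shape_attribute.h"

namespace paddle::dialect {

bool FooOpInferSymbolicShape(pir::Operation *op,
                             pir::ShapeConstraintIRAnalysis *shape_analysis) {
  pir::Value operand_source = op->operand_source(0);

  // Look up the symbolic shape already inferred for the input value.
  symbol::ShapeOrDataDimExprs operand_shape_or_data =
      shape_analysis->value_to_shape_or_data_[operand_source];

  // Reify the result on the op itself so later passes and the printer can
  // read it without re-running the analysis.
  op->set_attribute(
      "symbolic_shape",
      pir::shape::SymbolAttribute::get(pir::IrContext::Instance(),
                                       operand_shape_or_data));

  // Record the result value's symbolic shape in the analysis map.
  pir::OpResult res = op->result(0);
  shape_analysis->value_to_shape_or_data_[res] = operand_shape_or_data;
  return true;
}

}  // namespace paddle::dialect

Hooking such an op into the shape optimization pass then only requires declaring "interfaces : paddle::dialect::InferSymbolicShapeInterface" for it in ops.yaml, as done for data and full_int_array in this patch.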