diff --git a/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt b/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt
index 6d76ccbec8adc..dbe7f3c40adad 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt
+++ b/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt
@@ -29,4 +29,13 @@ if(NOT CINN_ONLY)
     cinn_op_dialect
     op_dialect_vjp)
 
+  cinn_cc_library(
+    fully_insert_broadcast_pass
+    SRCS
+    fully_insert_broadcast_pass.cc
+    DEPS
+    pir
+    cinn_op_dialect
+    op_dialect_vjp)
+
 endif()
diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc
index e43ae85a7e9c8..968929175b92b 100644
--- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc
+++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc
@@ -21,7 +21,8 @@ paddle::dialect::AddNOp, paddle::dialect::AddN_Op,
     paddle::dialect::ArrayWrite_Op, paddle::dialect::SliceArrayOp,
     paddle::dialect::SliceArrayDenseOp, paddle::dialect::AssignArray_Op,
     paddle::dialect::ArrayToTensorOp, paddle::dialect::SelectInputOp,
-    paddle::dialect::IncrementOp, paddle::dialect::Increment_Op
+    paddle::dialect::IncrementOp, paddle::dialect::Increment_Op,
+    paddle::dialect::ShapeBroadcastOp
 #else
 
 #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
@@ -33,6 +34,7 @@ paddle::dialect::AddNOp, paddle::dialect::AddN_Op,
 #include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
 #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
 #include "paddle/fluid/primitive/rule/vjp/vjp.h"
+#include "paddle/phi/api/lib/data_type_set.h"
 #include "paddle/phi/api/lib/utils/allocator.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/enforce.h"
@@ -46,7 +48,6 @@ paddle::dialect::AddNOp, paddle::dialect::AddN_Op,
 #include "paddle/pir/core/builtin_op.h"
 #include "paddle/pir/core/builtin_type.h"
 #include "paddle/pir/core/ir_context.h"
-#include "paddle/phi/api/lib/data_type_set.h"
 
 namespace paddle {
 namespace dialect {
@@ -2685,11 +2686,12 @@ phi::DataType Increment_Op::GetKernelTypeForVar(
   return expected_kernel_dtype;
 }
 
-void ShapeBroadcastOp::Build(pir::Builder &builder, pir::OperationArgument &argument, pir::Value x_, pir::Value y_) {
+void ShapeBroadcastOp::Build(pir::Builder &builder,
+                             pir::OperationArgument &argument,
+                             pir::Value x_,
+                             pir::Value y_) {
   VLOG(4) << "Start build ShapeBroadcastOp";
 
-
-  VLOG(4) << "Builder construction inputs";
   std::vector<pir::Value> argument_inputs = {x_, y_};
   argument.AddInputs(argument_inputs);
 
@@ -2697,24 +2699,28 @@ void ShapeBroadcastOp::Build(pir::Builder &builder, pir::OperationArgu
   VLOG(4) << "Builder construction attributes";
 
   VLOG(4) << "Builder construction outputs";
-  paddle::dialect::DenseTensorType x = x_.type().dyn_cast<paddle::dialect::DenseTensorType>();
-  paddle::dialect::DenseTensorType y = y_.type().dyn_cast<paddle::dialect::DenseTensorType>();
+  paddle::dialect::DenseTensorType x =
+      x_.type().dyn_cast<paddle::dialect::DenseTensorType>();
+  paddle::dialect::DenseTensorType y =
+      y_.type().dyn_cast<paddle::dialect::DenseTensorType>();
 
   VLOG(4) << "Builder construction dense_x";
-  paddle::dialect::IrTensor ir_tensor_x(paddle::dialect::TransToPhiDataType(x.dtype()),
-                                        x.dims(),
-                                        x.data_layout(),
-                                        x.lod(),
-                                        x.offset());
+  paddle::dialect::IrTensor ir_tensor_x(
+      paddle::dialect::TransToPhiDataType(x.dtype()),
+      x.dims(),
+      x.data_layout(),
+      x.lod(),
+      x.offset());
   VLOG(4) << "Builder construction meta_x";
   paddle::dialect::IrMetaTensor meta_x(&ir_tensor_x);
 
   VLOG(4) << "Builder construction dense_y";
-  paddle::dialect::IrTensor ir_tensor_y(paddle::dialect::TransToPhiDataType(y.dtype()),
-                                        y.dims(),
-                                        y.data_layout(),
-                                        y.lod(),
-                                        y.offset());
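+  // Mirror the x path above: wrap y's dtype, dims, layout, lod and offset in
+  // an IrTensor/IrMetaTensor pair so ElementwiseInferMeta below can derive
+  // the output type while the op is being built.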
+  paddle::dialect::IrTensor ir_tensor_y(
+      paddle::dialect::TransToPhiDataType(y.dtype()),
+      y.dims(),
+      y.data_layout(),
+      y.lod(),
+      y.offset());
   VLOG(4) << "Builder construction meta_y";
   paddle::dialect::IrMetaTensor meta_y(&ir_tensor_y);
   paddle::dialect::IrTensor dense_out;
@@ -2723,18 +2729,23 @@ void ShapeBroadcastOp::Build(pir::Builder &builder, pir::OperationArgu
   phi::ElementwiseInferMeta(meta_x, meta_y, &meta_out);
 
   std::vector<pir::Type> argument_outputs;
-  pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), paddle::dialect::TransToIrDataType(dense_out.dtype()), dense_out.dims(), dense_out.layout(), dense_out.lod(), dense_out.offset());
+  pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get(
+      pir::IrContext::Instance(),
+      paddle::dialect::TransToIrDataType(dense_out.dtype()),
+      dense_out.dims(),
+      dense_out.layout(),
+      dense_out.lod(),
+      dense_out.offset());
   argument_outputs.push_back(out_dense_tensor_type);
   argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
   ::pir::PassStopGradientsDefaultly(argument);
-
 }
 
 namespace {
 
-void ShapeBroadcastOpInferMeta(const phi::MetaTensor& x,
-                               const phi::MetaTensor& y,
-                               phi::MetaTensor* out) {
+void ShapeBroadcastOpInferMeta(const phi::MetaTensor &x,
+                               const phi::MetaTensor &y,
+                               phi::MetaTensor *out) {
   PADDLE_ENFORCE_EQ(x.dims().size(), 1);
   PADDLE_ENFORCE_EQ(y.dims().size(), 1);
   out->set_dims({std::max(x.dims().at(0), y.dims().at(0))});
@@ -2750,23 +2761,73 @@ void ShapeBroadcastOpInferMeta(const phi::MetaTensor& x,
   out->share_lod(x);
 }
 
-}
+}  // namespace
 
-void ShapeBroadcastOp::InferMeta( phi::InferMetaContext *infer_meta ) {
+void ShapeBroadcastOp::InferMeta(phi::InferMetaContext *infer_meta) {
   auto fn = PD_INFER_META(ShapeBroadcastOpInferMeta);
   fn(infer_meta);
 }
 
-
 phi::DataType ShapeBroadcastOp::GetKernelTypeForVar(
-    const std::string& var_name,
-    const phi::DataType& tensor_dtype,
-    const phi::DataType& expected_kernel_dtype) {
+    const std::string &var_name,
+    const phi::DataType &tensor_dtype,
+    const phi::DataType &expected_kernel_dtype) {
   VLOG(4) << "Get KernelType for Var of op: ShapeBroadcastOp";
 
   return expected_kernel_dtype;
 }
 
+namespace {
+
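+// Broadcast one pair of symbolic dims with numpy-style semantics: two
+// integer constants must be equal, a constant 1 broadcasts to the other
+// operand, and two symbolic operands fold into a deferred symbol::Broadcast
+// expression for the shape analysis to simplify later.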
+symbol::DimExpr GetBroadcastDimExpr(const symbol::DimExpr &lhs,
+                                    const symbol::DimExpr &rhs) {
+  if (lhs.isa<std::int64_t>() && rhs.isa<std::int64_t>()) {
+    CHECK_EQ(lhs.dyn_cast<std::int64_t>(), rhs.dyn_cast<std::int64_t>());
+    return lhs;  // Both sides are the same constant.
+  } else if (lhs.isa<std::int64_t>()) {
+    return lhs.dyn_cast<std::int64_t>() == 1 ? rhs : lhs;
+  } else if (rhs.isa<std::int64_t>()) {
+    return rhs.dyn_cast<std::int64_t>() == 1 ? lhs : rhs;
+  } else {
+    return symbol::Broadcast<symbol::DimExpr>{
+        symbol::List<symbol::DimExpr>{lhs, rhs}};
+  }
+}
+
+}  // namespace
+
+bool ShapeBroadcastOp::InferSymbolicShape(
+    pir::ShapeConstraintIRAnalysis *shape_analysis) {
+  pir::Value x = operand_source(0);
+  pir::Value y = operand_source(1);
+  std::string x_id = pir::GetValueId(&x);
+  std::string y_id = pir::GetValueId(&y);
+
+  IR_ENFORCE(shape_analysis->value_id_to_shapeordata_.count(x_id) > 0,
+             "x_id does not exist.");
+  IR_ENFORCE(shape_analysis->value_id_to_shapeordata_.count(y_id) > 0,
+             "y_id does not exist.");
+  const auto &x_data_shape = shape_analysis->value_id_to_shapeordata_.at(x_id);
+  const auto &y_data_shape = shape_analysis->value_id_to_shapeordata_.at(y_id);
+  IR_ENFORCE(x_data_shape.data().has_value(),
+             "Value x comes from ShapeOp, it must have data");
+  IR_ENFORCE(y_data_shape.data().has_value(),
+             "Value y comes from ShapeOp, it must have data");
+  const auto &x_data = x_data_shape.data().value();
+  const auto &y_data = y_data_shape.data().value();
+  IR_ENFORCE(x_data.size() == y_data.size(), "Support same rank temporarily");
+
+  std::vector<symbol::DimExpr> output_data;
+  for (std::size_t i = 0; i < x_data.size(); ++i) {
+    output_data.emplace_back(GetBroadcastDimExpr(x_data.at(i), y_data.at(i)));
+  }
+
+  pir::OpResult res = result(0);
+  std::string res_id = pir::GetValueId(&res);
+  symbol::ShapeOrDataDimExprs output_data_shape =
+      symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(output_data);
+  shape_analysis->value_id_to_shapeordata_[res_id] = output_data_shape;
+  return true;
+}
+
 }  // namespace dialect
 }  // namespace paddle
diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h
index a87b27760f920..71741e550889f 100644
--- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h
+++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h
@@ -18,6 +18,7 @@
 #include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/pir/dialect/operator/interface/decomp.h"
 #include "paddle/fluid/pir/dialect/operator/interface/get_kernel_type_for_var.h"
+#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h"
 #include "paddle/fluid/pir/dialect/operator/interface/infermeta.h"
 #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
 #include "paddle/fluid/pir/dialect/operator/interface/vjp.h"
@@ -515,16 +516,22 @@ class Increment_Op
       const std::vector<std::vector<bool>> &stop_gradients);
 };
 
-class IR_API ShapeBroadcastOp : public pir::Op<ShapeBroadcastOp, paddle::dialect::InferMetaInterface, paddle::dialect::GetKernelTypeForVarInterface> {
+class IR_API ShapeBroadcastOp
+    : public pir::Op<ShapeBroadcastOp,
+                     paddle::dialect::InferSymbolicShapeInterface,
+                     paddle::dialect::InferMetaInterface,
+                     paddle::dialect::GetKernelTypeForVarInterface> {
  public:
   using Op::Op;
   static const char *name() { return "pd_op.shape_broadcast"; }
   static constexpr const char **attributes_name = nullptr;
   static constexpr uint32_t attributes_num = 0;
 
-  static void Build(pir::Builder &builder, pir::OperationArgument &argument, pir::Value x_, pir::Value y_);
-
-  void VerifySig() {}
+  static void Build(pir::Builder &builder,             // NOLINT
+                    pir::OperationArgument &argument,  // NOLINT
+                    pir::Value x_,
+                    pir::Value y_);
+  void VerifySig() {}
 
   pir::Value x() { return operand_source(0); }
   pir::Value y() { return operand_source(1); }
@@ -533,10 +540,11 @@ class IR_API ShapeBroadcastOp : public pir::Op