diff --git a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc index eef663585a408..f64bb9269d63a 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h" +#include #include "paddle/pir/core/builder.h" #include "paddle/pir/core/builtin_attribute.h" @@ -422,4 +423,238 @@ MakeGetterDimExpr4SymbolName( }; } +namespace { + +bool IsAtomicImpl(int64_t) { return true; } + +bool IsAtomicImpl(const std::string&) { return true; } + +bool IsAtomicImpl(const symbol::Negative&) { return false; } + +bool IsAtomicImpl(const symbol::Reciprocal&) { return false; } + +bool IsAtomicImpl(const symbol::Add&) { return false; } + +bool IsAtomicImpl(const symbol::Mul&) { return false; } + +bool IsAtomicImpl(const symbol::Max&) { return false; } + +bool IsAtomicImpl(const symbol::Min&) { return false; } + +bool IsAtomicImpl(const symbol::Broadcast&) { return false; } + +bool IsAtomic(const symbol::DimExpr& dim_expr) { + return std::visit([](const auto& impl) { return IsAtomicImpl(impl); }, + dim_expr.variant()); +} + +bool InputDimExprsAllSupported( + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, + const std::vector& input_tensors) { + const auto& AllSupported = + [](const std::vector& dim_exprs) -> bool { + for (const auto& dim_expr : dim_exprs) { + if (!IsAtomic(dim_expr)) return false; + } + return true; + }; + for (const auto& input_tensor : input_tensors) { + const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor); + if (!AllSupported(dim_exprs.shape())) return false; + if (dim_exprs.data().has_value()) { + if (!AllSupported(dim_exprs.data().value())) return false; + } + } + return true; +} + +void ConvertDimExprToAttributes(pir::IrContext* ir_context, + const std::vector& dim_exprs, + std::vector* attrs) { + attrs->clear(); + attrs->reserve(dim_exprs.size()); + for (const auto& dim_expr : dim_exprs) { + attrs->emplace_back(ConvertDimExprToAttribute(ir_context, dim_expr)); + } +} + +void CollectSymbolNames(const symbol::DimExpr& dim_expr, + std::set* ret); + +void CollectSymbolNamesImpl(const int64_t& dim_expr, + std::set* ret) { + // do nothing. +} + +void CollectSymbolNamesImpl(const std::string& dim_expr, + std::set* ret) { + ret->insert(dim_expr); +} + +template +void CollectSymbolNamesImplForUnary(const T& dim_expr, + std::set* ret) { + const auto& [operand] = *dim_expr; + CollectSymbolNames(operand, ret); +} + +void CollectSymbolNamesImpl(const symbol::Negative& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForUnary(dim_expr, ret); +} + +void CollectSymbolNamesImpl(const symbol::Reciprocal& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForUnary(dim_expr, ret); +} + +template +void CollectSymbolNamesImplForVariadic(const T& dim_expr, + std::set* ret) { + const auto& operands = *(dim_expr.operands); + for (const auto& operand : operands) { + CollectSymbolNames(operand, ret); + } +} + +void CollectSymbolNamesImpl(const symbol::Add& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForVariadic(dim_expr, ret); +} + +void CollectSymbolNamesImpl(const symbol::Mul& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForVariadic(dim_expr, ret); +} + +void CollectSymbolNamesImpl(const symbol::Max& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForVariadic(dim_expr, ret); +} + +void CollectSymbolNamesImpl(const symbol::Min& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForVariadic(dim_expr, ret); +} + +void CollectSymbolNamesImpl(const symbol::Broadcast& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForVariadic(dim_expr, ret); +} + +void CollectSymbolNames(const symbol::DimExpr& dim_expr, + std::set* ret) { + return std::visit( + [&](const auto& impl) { return CollectSymbolNamesImpl(impl, ret); }, + dim_expr.variant()); +} + +void CollectSymbolNames(const std::vector& dim_exprs, + std::set* ret) { + for (const auto& dim_expr : dim_exprs) { + CollectSymbolNames(dim_expr, ret); + } +} + +template +void AppendSymbolBindings(const std::vector& dim_exprs, + const std::set& symbol_names, + int in_tensor_idx, + GenerateShapeOp::SymbolBindings* symbol_bindings) { + for (int in_tensor_dim_idx = 0; in_tensor_dim_idx < dim_exprs.size(); + ++in_tensor_dim_idx) { + const auto& dim_expr = dim_exprs.at(in_tensor_dim_idx); + CHECK(IsAtomic(dim_expr)); + if (!dim_expr.isa()) continue; + const auto& sym_name = dim_expr.dyn_cast(); + if (symbol_names.find(sym_name) == symbol_names.end()) continue; + symbol_bindings->emplace_back(SymbolBindingsT{ + /*.symbol_name=*/sym_name, + /*.input_tensor_idx=*/in_tensor_idx, + /*.input_tensor_dim_idx=*/in_tensor_dim_idx, + }); + } +} + +void GenerateSymbolBindings( + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, + const std::vector& input_tensors, + const std::set& symbol_names, + GenerateShapeOp::SymbolBindings* symbol_bindings) { + for (int i = 0; i < input_tensors.size(); ++i) { + const auto& input_tensor = input_tensors.at(i); + const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor); + AppendSymbolBindings( + dim_exprs.shape(), symbol_names, i, symbol_bindings); + if (dim_exprs.data().has_value()) { + AppendSymbolBindings( + dim_exprs.shape(), symbol_names, i, symbol_bindings); + } + } +} + +std::vector GetMinimalInputs( + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, + const std::vector& input_tensors) { + std::unordered_set handdled_dim_exprs; + std::unordered_set first_occurred_input_tensors; + auto TryCollectFirstOcurredInput_tensor = + [&](pir::Value input_tensor, + const std::vector& dim_exprs) { + for (const auto& dim_expr : dim_exprs) { + if (dim_expr.isa()) continue; + if (!handdled_dim_exprs.insert(dim_expr).second) { + first_occurred_input_tensors.insert(input_tensor); + } + } + }; + for (pir::Value input_tensor : input_tensors) { + const auto& shape_or_data_dim_exprs = + ShapeOrDataDimExprs4Value(input_tensor); + if (shape_or_data_dim_exprs.data().has_value()) { + TryCollectFirstOcurredInput_tensor( + input_tensor, shape_or_data_dim_exprs.data().value()); + } + TryCollectFirstOcurredInput_tensor(input_tensor, + shape_or_data_dim_exprs.shape()); + } + std::vector ret{}; + ret.reserve(input_tensors.size()); + for (pir::Value input_tensor : input_tensors) { + if (first_occurred_input_tensors.count(input_tensor) > 0) { + ret.emplace_back(input_tensor); + } + } + return ret; +} + +} // namespace + +bool MakeGenerateShapeOpAttribute( + pir::IrContext* ir_context, + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, + const std::vector& out_dim_exprs, + const std::vector& origin_inputs, + std::vector* minial_inputs, + std::vector* output_dim_expr_attrs, + GenerateShapeOp::SymbolBindings* symbol_bindings) { + *minial_inputs = GetMinimalInputs(ShapeOrDataDimExprs4Value, origin_inputs); + if (!InputDimExprsAllSupported(ShapeOrDataDimExprs4Value, *minial_inputs)) { + VLOG(4) << "input dim_exprs are not as simple as symbols, please make sure " + "they are handled by other passes"; + return false; + } + // generate output_dim_expr_attrs + ConvertDimExprToAttributes( + ir_context, out_dim_exprs, /*out*/ output_dim_expr_attrs); + // generate symbol_bindings + std::set symbol_names_in_out_dim_exprs{}; + CollectSymbolNames(out_dim_exprs, &symbol_names_in_out_dim_exprs); + GenerateSymbolBindings(ShapeOrDataDimExprs4Value, + *minial_inputs, + symbol_names_in_out_dim_exprs, + /*out*/ symbol_bindings); + return true; +} + } // namespace cinn::dialect diff --git a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h index ee4ad3c129e6b..401c240f61e86 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h +++ b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h @@ -14,7 +14,9 @@ #pragma once +#include #include +#include #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/pir/core/builder.h" #include "paddle/pir/dialect/shape/utils/dim_expr.h" @@ -46,4 +48,17 @@ MakeGetterDimExpr4SymbolName( const std::function& DimExpr4InputDim); +using ShapeOrDataDimExprs4ValueT = + std::function; + +// Returns true if success. +bool MakeGenerateShapeOpAttribute( + pir::IrContext* ir_context, + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, + const std::vector& out_dim_exprs, + const std::vector& origin_inputs, + std::vector* minial_inputs, + std::vector* output_dim_expr_attrs, + GenerateShapeOp::SymbolBindings* symbol_bindings); + } // namespace cinn::dialect diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc index 48c7427b402a1..0bff6c7daa886 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc @@ -38,9 +38,6 @@ namespace ir { namespace { -using ShapeOrDataDimExprs4ValueT = - std::function; - std::vector FindSourceDenseTensorOfDimTensor( pir::Value shape, const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) { @@ -86,209 +83,24 @@ std::vector FindSourceDenseTensorOfDimTensor( return ret; } -bool IsConstant(const std::vector& dim_exprs) { - for (const auto& dim_expr : dim_exprs) { - if (dim_expr.isa()) continue; - return false; - } - return true; -} - -bool IsAtomicImpl(int64_t) { return true; } - -bool IsAtomicImpl(const std::string&) { return true; } - -bool IsAtomicImpl(const symbol::Negative&) { return false; } - -bool IsAtomicImpl(const symbol::Reciprocal&) { return false; } - -bool IsAtomicImpl(const symbol::Add&) { return false; } - -bool IsAtomicImpl(const symbol::Mul&) { return false; } - -bool IsAtomicImpl(const symbol::Max&) { return false; } - -bool IsAtomicImpl(const symbol::Min&) { return false; } - -bool IsAtomicImpl(const symbol::Broadcast&) { return false; } - -bool IsAtomic(const symbol::DimExpr& dim_expr) { - return std::visit([](const auto& impl) { return IsAtomicImpl(impl); }, - dim_expr.variant()); -} - -bool InputDimExprsAllSupported( - const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, - const std::vector& input_tensors) { - const auto& AllSupported = - [](const std::vector& dim_exprs) -> bool { - for (const auto& dim_expr : dim_exprs) { - if (!IsAtomic(dim_expr)) return false; - } - return true; - }; - for (const auto& input_tensor : input_tensors) { - const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor); - if (!AllSupported(dim_exprs.shape())) return false; - if (dim_exprs.data().has_value()) { - if (!AllSupported(dim_exprs.data().value())) return false; - } - } - return true; -} - -void ConvertDimExprToAttributes(pir::IrContext* ir_context, - const std::vector& dim_exprs, - std::vector* attrs) { - attrs->clear(); - attrs->reserve(dim_exprs.size()); - for (const auto& dim_expr : dim_exprs) { - attrs->emplace_back(ConvertDimExprToAttribute(ir_context, dim_expr)); - } -} - -void CollectSymbolNames(const symbol::DimExpr& dim_expr, - std::set* ret); - -void CollectSymbolNamesImpl(const int64_t& dim_expr, - std::set* ret) { - // do nothing. -} - -void CollectSymbolNamesImpl(const std::string& dim_expr, - std::set* ret) { - ret->insert(dim_expr); -} - -template -void CollectSymbolNamesImplForUnary(const T& dim_expr, - std::set* ret) { - const auto& [operand] = *dim_expr; - CollectSymbolNames(operand, ret); -} - -void CollectSymbolNamesImpl(const symbol::Negative& dim_expr, - std::set* ret) { - CollectSymbolNamesImplForUnary(dim_expr, ret); -} - -void CollectSymbolNamesImpl(const symbol::Reciprocal& dim_expr, - std::set* ret) { - CollectSymbolNamesImplForUnary(dim_expr, ret); -} - -template -void CollectSymbolNamesImplForVariadic(const T& dim_expr, - std::set* ret) { - const auto& operands = *(dim_expr.operands); - for (const auto& operand : operands) { - CollectSymbolNames(operand, ret); - } -} - -void CollectSymbolNamesImpl(const symbol::Add& dim_expr, - std::set* ret) { - CollectSymbolNamesImplForVariadic(dim_expr, ret); -} - -void CollectSymbolNamesImpl(const symbol::Mul& dim_expr, - std::set* ret) { - CollectSymbolNamesImplForVariadic(dim_expr, ret); -} - -void CollectSymbolNamesImpl(const symbol::Max& dim_expr, - std::set* ret) { - CollectSymbolNamesImplForVariadic(dim_expr, ret); -} - -void CollectSymbolNamesImpl(const symbol::Min& dim_expr, - std::set* ret) { - CollectSymbolNamesImplForVariadic(dim_expr, ret); -} - -void CollectSymbolNamesImpl(const symbol::Broadcast& dim_expr, - std::set* ret) { - CollectSymbolNamesImplForVariadic(dim_expr, ret); -} - -void CollectSymbolNames(const symbol::DimExpr& dim_expr, - std::set* ret) { - return std::visit( - [&](const auto& impl) { return CollectSymbolNamesImpl(impl, ret); }, - dim_expr.variant()); -} - -void CollectSymbolNames(const std::vector& dim_exprs, - std::set* ret) { - for (const auto& dim_expr : dim_exprs) { - CollectSymbolNames(dim_expr, ret); - } -} - -template -void AppendSymbolBindings(const std::vector& dim_exprs, - const std::set& symbol_names, - int in_tensor_idx, - GenerateShapeOp::SymbolBindings* symbol_bindings) { - for (int in_tensor_dim_idx = 0; in_tensor_dim_idx < dim_exprs.size(); - ++in_tensor_dim_idx) { - const auto& dim_expr = dim_exprs.at(in_tensor_dim_idx); - CHECK(IsAtomic(dim_expr)); - if (!dim_expr.isa()) continue; - const auto& sym_name = dim_expr.dyn_cast(); - if (symbol_names.find(sym_name) == symbol_names.end()) continue; - symbol_bindings->emplace_back(SymbolBindingsT{ - /*.symbol_name=*/sym_name, - /*.input_tensor_idx=*/in_tensor_idx, - /*.input_tensor_dim_idx=*/in_tensor_dim_idx, - }); - } -} - -void GenerateSymbolBindings( - const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, - const std::vector& input_tensors, - const std::set& symbol_names, - GenerateShapeOp::SymbolBindings* symbol_bindings) { - for (int i = 0; i < input_tensors.size(); ++i) { - const auto& input_tensor = input_tensors.at(i); - const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor); - AppendSymbolBindings( - dim_exprs.shape(), symbol_names, i, symbol_bindings); - if (dim_exprs.data().has_value()) { - AppendSymbolBindings( - dim_exprs.shape(), symbol_names, i, symbol_bindings); - } - } -} - bool MakeGenerateShapeOpAttribute( pir::IrContext* ir_context, const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, - const std::vector& input_tensors, pir::Value output_shape, + const std::vector& origin_inputs, + std::vector* minimal_inputs, std::vector* output_dim_expr_attrs, GenerateShapeOp::SymbolBindings* symbol_bindings) { const auto& shape_or_data_dim_exprs = ShapeOrDataDimExprs4Value(output_shape); CHECK(shape_or_data_dim_exprs.data().has_value()); const auto& out_dim_exprs = shape_or_data_dim_exprs.data().value(); - if (IsConstant(out_dim_exprs)) return false; - if (!InputDimExprsAllSupported(ShapeOrDataDimExprs4Value, input_tensors)) { - VLOG(4) << "input dim_exprs are not as simple as symbols, please make sure " - "they are handled by other passes"; - return false; - } - // generate output_dim_expr_attrs - ConvertDimExprToAttributes( - ir_context, out_dim_exprs, /*out*/ output_dim_expr_attrs); - // generate symbol_bindings - std::set symbol_names_in_out_dim_exprs{}; - CollectSymbolNames(out_dim_exprs, &symbol_names_in_out_dim_exprs); - GenerateSymbolBindings(ShapeOrDataDimExprs4Value, - input_tensors, - symbol_names_in_out_dim_exprs, - /*out*/ symbol_bindings); - return true; + return MakeGenerateShapeOpAttribute(ir_context, + ShapeOrDataDimExprs4Value, + out_dim_exprs, + origin_inputs, + minimal_inputs, + output_dim_expr_attrs, + symbol_bindings); } std::optional GetOutOfRewritedGenerateShapeOp( @@ -302,8 +114,9 @@ std::optional GetOutOfRewritedGenerateShapeOp( GenerateShapeOp::SymbolBindings symbol_bindings{}; bool success = MakeGenerateShapeOpAttribute(rewriter->ir_context(), ShapeOrDataDimExprs4Value, - input_tensors, shape, + /*origin inputs*/ input_tensors, + /*minimal inputs*/ &input_tensors, &output_dim_expr_attrs, &symbol_bindings); if (!success) return std::nullopt;