From 0b2487b14a7097092c967e658700ca44550929bc Mon Sep 17 00:00:00 2001 From: jiahy0825 Date: Tue, 2 Jan 2024 12:31:41 +0000 Subject: [PATCH 1/3] add helper function MakeGenerateShapeOpAttribute --- .../operator/ir/generate_shape_util.cc | 232 ++++++++++++++++++ .../dialect/operator/ir/generate_shape_util.h | 15 ++ ...e_shape_ops_into_generate_shape_op_pass.cc | 211 +--------------- 3 files changed, 260 insertions(+), 198 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc index eef663585a408..df2b55aeb7280 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc @@ -15,6 +15,7 @@ #include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h" #include "paddle/pir/core/builder.h" #include "paddle/pir/core/builtin_attribute.h" +#include namespace cinn::dialect { using namespace symbol; // NOLINT @@ -422,4 +423,235 @@ MakeGetterDimExpr4SymbolName( }; } +namespace { + +bool IsAtomicImpl(int64_t) { return true; } + +bool IsAtomicImpl(const std::string&) { return true; } + +bool IsAtomicImpl(const symbol::Negative&) { return false; } + +bool IsAtomicImpl(const symbol::Reciprocal&) { return false; } + +bool IsAtomicImpl(const symbol::Add&) { return false; } + +bool IsAtomicImpl(const symbol::Mul&) { return false; } + +bool IsAtomicImpl(const symbol::Max&) { return false; } + +bool IsAtomicImpl(const symbol::Min&) { return false; } + +bool IsAtomicImpl(const symbol::Broadcast&) { return false; } + +bool IsAtomic(const symbol::DimExpr& dim_expr) { + return std::visit([](const auto& impl) { return IsAtomicImpl(impl); }, + dim_expr.variant()); +} + +bool InputDimExprsAllSupported( + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, + const std::vector& input_tensors) { + const auto& AllSupported = + [](const std::vector& dim_exprs) -> bool { + for (const auto& dim_expr : dim_exprs) { + if (!IsAtomic(dim_expr)) return false; + } + return true; + }; + for (const auto& input_tensor : input_tensors) { + const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor); + if (!AllSupported(dim_exprs.shape())) return false; + if (dim_exprs.data().has_value()) { + if (!AllSupported(dim_exprs.data().value())) return false; + } + } + return true; +} + +void ConvertDimExprToAttributes(pir::IrContext* ir_context, + const std::vector& dim_exprs, + std::vector* attrs) { + attrs->clear(); + attrs->reserve(dim_exprs.size()); + for (const auto& dim_expr : dim_exprs) { + attrs->emplace_back(ConvertDimExprToAttribute(ir_context, dim_expr)); + } +} + +void CollectSymbolNames(const symbol::DimExpr& dim_expr, + std::set* ret); + +void CollectSymbolNamesImpl(const int64_t& dim_expr, + std::set* ret) { + // do nothing. +} + +void CollectSymbolNamesImpl(const std::string& dim_expr, + std::set* ret) { + ret->insert(dim_expr); +} + +template +void CollectSymbolNamesImplForUnary(const T& dim_expr, + std::set* ret) { + const auto& [operand] = *dim_expr; + CollectSymbolNames(operand, ret); +} + +void CollectSymbolNamesImpl(const symbol::Negative& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForUnary(dim_expr, ret); +} + +void CollectSymbolNamesImpl(const symbol::Reciprocal& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForUnary(dim_expr, ret); +} + +template +void CollectSymbolNamesImplForVariadic(const T& dim_expr, + std::set* ret) { + const auto& operands = *(dim_expr.operands); + for (const auto& operand : operands) { + CollectSymbolNames(operand, ret); + } +} + +void CollectSymbolNamesImpl(const symbol::Add& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForVariadic(dim_expr, ret); +} + +void CollectSymbolNamesImpl(const symbol::Mul& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForVariadic(dim_expr, ret); +} + +void CollectSymbolNamesImpl(const symbol::Max& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForVariadic(dim_expr, ret); +} + +void CollectSymbolNamesImpl(const symbol::Min& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForVariadic(dim_expr, ret); +} + +void CollectSymbolNamesImpl(const symbol::Broadcast& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForVariadic(dim_expr, ret); +} + +void CollectSymbolNames(const symbol::DimExpr& dim_expr, + std::set* ret) { + return std::visit( + [&](const auto& impl) { return CollectSymbolNamesImpl(impl, ret); }, + dim_expr.variant()); +} + +void CollectSymbolNames(const std::vector& dim_exprs, + std::set* ret) { + for (const auto& dim_expr : dim_exprs) { + CollectSymbolNames(dim_expr, ret); + } +} + +template +void AppendSymbolBindings(const std::vector& dim_exprs, + const std::set& symbol_names, + int in_tensor_idx, + GenerateShapeOp::SymbolBindings* symbol_bindings) { + for (int in_tensor_dim_idx = 0; in_tensor_dim_idx < dim_exprs.size(); + ++in_tensor_dim_idx) { + const auto& dim_expr = dim_exprs.at(in_tensor_dim_idx); + CHECK(IsAtomic(dim_expr)); + if (!dim_expr.isa()) continue; + const auto& sym_name = dim_expr.dyn_cast(); + if (symbol_names.find(sym_name) == symbol_names.end()) continue; + symbol_bindings->emplace_back(SymbolBindingsT{ + /*.symbol_name=*/sym_name, + /*.input_tensor_idx=*/in_tensor_idx, + /*.input_tensor_dim_idx=*/in_tensor_dim_idx, + }); + } +} + +void GenerateSymbolBindings( + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, + const std::vector& input_tensors, + const std::set& symbol_names, + GenerateShapeOp::SymbolBindings* symbol_bindings) { + for (int i = 0; i < input_tensors.size(); ++i) { + const auto& input_tensor = input_tensors.at(i); + const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor); + AppendSymbolBindings( + dim_exprs.shape(), symbol_names, i, symbol_bindings); + if (dim_exprs.data().has_value()) { + AppendSymbolBindings( + dim_exprs.shape(), symbol_names, i, symbol_bindings); + } + } +} + +std::vector GetMinimalInputs( + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, + const std::vector& input_tensors) { + std::unordered_set handdled_dim_exprs; + std::unordered_set first_occurred_input_tensors; + auto TryCollectFirstOcurredInput_tensor = [&](pir::Value input_tensor, const auto& dim_exprs) { + for (const auto& dim_expr : dim_exprs) { + if (dim_expr.Has()) continue; + if (!handdled_dim_exprs.insert(dim_expr).second) { + first_occurred_input_tensors.insert(input_tensor); + } + } + }; + for (pir::Value input_tensor : input_tensors) { + const auto& shape_or_data_dim_exprs = ShapeOrDataDimExprs4Value(input_tensor); + if (shape_or_data_dim_exprs.data().has_value()) { + TryCollectFirstOcurredInput_tensor( + input_tensor, shape_or_data_dim_exprs.data().value()); + } + TryCollectFirstOcurredInput_tensor( + input_tensor, shape_or_data_dim_exprs.shape()); + } + std::vector ret{}; + ret.reserve(input_tensors.size()); + for (pir::Value input_tensor : input_tensors) { + if (first_occurred_input_tensors.count(input_tensor) > 0) { + ret.emplace_back(input_tensor); + } + } + return ret; +} + +} + +bool MakeGenerateShapeOpAttribute( + pir::IrContext* ir_context, + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, + const std::vector& out_dim_exprs, + const std::vector& origin_inputs, + std::vector* minial_inputs, + std::vector* output_dim_expr_attrs, + GenerateShapeOp::SymbolBindings* symbol_bindings) { + *minial_inputs = GetMinimalInputs(ShapeOrDataDimExprs4Value, origin_inputs); + if (!InputDimExprsAllSupported(ShapeOrDataDimExprs4Value, *minial_inputs)) { + VLOG(4) << "input dim_exprs are not as simple as symbols, please make sure " + "they are handled by other passes"; + return false; + } + // generate output_dim_expr_attrs + ConvertDimExprToAttributes( + ir_context, out_dim_exprs, /*out*/ output_dim_expr_attrs); + // generate symbol_bindings + std::set symbol_names_in_out_dim_exprs{}; + CollectSymbolNames(out_dim_exprs, &symbol_names_in_out_dim_exprs); + GenerateSymbolBindings(ShapeOrDataDimExprs4Value, + *minial_inputs, + symbol_names_in_out_dim_exprs, + /*out*/ symbol_bindings); + return true; +} + } // namespace cinn::dialect diff --git a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h index ee4ad3c129e6b..af5d0d9bdf14b 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h +++ b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h @@ -18,6 +18,8 @@ #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/pir/core/builder.h" #include "paddle/pir/dialect/shape/utils/dim_expr.h" +#include +#include namespace cinn::dialect { @@ -46,4 +48,17 @@ MakeGetterDimExpr4SymbolName( const std::function& DimExpr4InputDim); +using ShapeOrDataDimExprs4ValueT = + std::function; + +// Returns true if success. +bool MakeGenerateShapeOpAttribute( + pir::IrContext* ir_context, + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, + const std::vector& out_dim_exprs, + const std::vector& origin_inputs, + std::vector* minial_inputs, + std::vector* output_dim_expr_attrs, + GenerateShapeOp::SymbolBindings* symbol_bindings); + } // namespace cinn::dialect diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc index 48c7427b402a1..5a43660795d61 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc @@ -38,9 +38,6 @@ namespace ir { namespace { -using ShapeOrDataDimExprs4ValueT = - std::function; - std::vector FindSourceDenseTensorOfDimTensor( pir::Value shape, const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) { @@ -86,209 +83,26 @@ std::vector FindSourceDenseTensorOfDimTensor( return ret; } -bool IsConstant(const std::vector& dim_exprs) { - for (const auto& dim_expr : dim_exprs) { - if (dim_expr.isa()) continue; - return false; - } - return true; -} - -bool IsAtomicImpl(int64_t) { return true; } - -bool IsAtomicImpl(const std::string&) { return true; } - -bool IsAtomicImpl(const symbol::Negative&) { return false; } - -bool IsAtomicImpl(const symbol::Reciprocal&) { return false; } - -bool IsAtomicImpl(const symbol::Add&) { return false; } - -bool IsAtomicImpl(const symbol::Mul&) { return false; } - -bool IsAtomicImpl(const symbol::Max&) { return false; } - -bool IsAtomicImpl(const symbol::Min&) { return false; } - -bool IsAtomicImpl(const symbol::Broadcast&) { return false; } - -bool IsAtomic(const symbol::DimExpr& dim_expr) { - return std::visit([](const auto& impl) { return IsAtomicImpl(impl); }, - dim_expr.variant()); -} - -bool InputDimExprsAllSupported( - const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, - const std::vector& input_tensors) { - const auto& AllSupported = - [](const std::vector& dim_exprs) -> bool { - for (const auto& dim_expr : dim_exprs) { - if (!IsAtomic(dim_expr)) return false; - } - return true; - }; - for (const auto& input_tensor : input_tensors) { - const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor); - if (!AllSupported(dim_exprs.shape())) return false; - if (dim_exprs.data().has_value()) { - if (!AllSupported(dim_exprs.data().value())) return false; - } - } - return true; -} - -void ConvertDimExprToAttributes(pir::IrContext* ir_context, - const std::vector& dim_exprs, - std::vector* attrs) { - attrs->clear(); - attrs->reserve(dim_exprs.size()); - for (const auto& dim_expr : dim_exprs) { - attrs->emplace_back(ConvertDimExprToAttribute(ir_context, dim_expr)); - } -} - -void CollectSymbolNames(const symbol::DimExpr& dim_expr, - std::set* ret); - -void CollectSymbolNamesImpl(const int64_t& dim_expr, - std::set* ret) { - // do nothing. -} - -void CollectSymbolNamesImpl(const std::string& dim_expr, - std::set* ret) { - ret->insert(dim_expr); -} - -template -void CollectSymbolNamesImplForUnary(const T& dim_expr, - std::set* ret) { - const auto& [operand] = *dim_expr; - CollectSymbolNames(operand, ret); -} - -void CollectSymbolNamesImpl(const symbol::Negative& dim_expr, - std::set* ret) { - CollectSymbolNamesImplForUnary(dim_expr, ret); -} - -void CollectSymbolNamesImpl(const symbol::Reciprocal& dim_expr, - std::set* ret) { - CollectSymbolNamesImplForUnary(dim_expr, ret); -} - -template -void CollectSymbolNamesImplForVariadic(const T& dim_expr, - std::set* ret) { - const auto& operands = *(dim_expr.operands); - for (const auto& operand : operands) { - CollectSymbolNames(operand, ret); - } -} - -void CollectSymbolNamesImpl(const symbol::Add& dim_expr, - std::set* ret) { - CollectSymbolNamesImplForVariadic(dim_expr, ret); -} - -void CollectSymbolNamesImpl(const symbol::Mul& dim_expr, - std::set* ret) { - CollectSymbolNamesImplForVariadic(dim_expr, ret); -} - -void CollectSymbolNamesImpl(const symbol::Max& dim_expr, - std::set* ret) { - CollectSymbolNamesImplForVariadic(dim_expr, ret); -} - -void CollectSymbolNamesImpl(const symbol::Min& dim_expr, - std::set* ret) { - CollectSymbolNamesImplForVariadic(dim_expr, ret); -} - -void CollectSymbolNamesImpl(const symbol::Broadcast& dim_expr, - std::set* ret) { - CollectSymbolNamesImplForVariadic(dim_expr, ret); -} - -void CollectSymbolNames(const symbol::DimExpr& dim_expr, - std::set* ret) { - return std::visit( - [&](const auto& impl) { return CollectSymbolNamesImpl(impl, ret); }, - dim_expr.variant()); -} - -void CollectSymbolNames(const std::vector& dim_exprs, - std::set* ret) { - for (const auto& dim_expr : dim_exprs) { - CollectSymbolNames(dim_expr, ret); - } -} - -template -void AppendSymbolBindings(const std::vector& dim_exprs, - const std::set& symbol_names, - int in_tensor_idx, - GenerateShapeOp::SymbolBindings* symbol_bindings) { - for (int in_tensor_dim_idx = 0; in_tensor_dim_idx < dim_exprs.size(); - ++in_tensor_dim_idx) { - const auto& dim_expr = dim_exprs.at(in_tensor_dim_idx); - CHECK(IsAtomic(dim_expr)); - if (!dim_expr.isa()) continue; - const auto& sym_name = dim_expr.dyn_cast(); - if (symbol_names.find(sym_name) == symbol_names.end()) continue; - symbol_bindings->emplace_back(SymbolBindingsT{ - /*.symbol_name=*/sym_name, - /*.input_tensor_idx=*/in_tensor_idx, - /*.input_tensor_dim_idx=*/in_tensor_dim_idx, - }); - } -} - -void GenerateSymbolBindings( - const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, - const std::vector& input_tensors, - const std::set& symbol_names, - GenerateShapeOp::SymbolBindings* symbol_bindings) { - for (int i = 0; i < input_tensors.size(); ++i) { - const auto& input_tensor = input_tensors.at(i); - const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor); - AppendSymbolBindings( - dim_exprs.shape(), symbol_names, i, symbol_bindings); - if (dim_exprs.data().has_value()) { - AppendSymbolBindings( - dim_exprs.shape(), symbol_names, i, symbol_bindings); - } - } -} - bool MakeGenerateShapeOpAttribute( pir::IrContext* ir_context, const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, - const std::vector& input_tensors, pir::Value output_shape, + const std::vector& origin_inputs, + std::vector* minimal_inputs, std::vector* output_dim_expr_attrs, GenerateShapeOp::SymbolBindings* symbol_bindings) { const auto& shape_or_data_dim_exprs = ShapeOrDataDimExprs4Value(output_shape); CHECK(shape_or_data_dim_exprs.data().has_value()); const auto& out_dim_exprs = shape_or_data_dim_exprs.data().value(); - if (IsConstant(out_dim_exprs)) return false; - if (!InputDimExprsAllSupported(ShapeOrDataDimExprs4Value, input_tensors)) { - VLOG(4) << "input dim_exprs are not as simple as symbols, please make sure " - "they are handled by other passes"; - return false; - } - // generate output_dim_expr_attrs - ConvertDimExprToAttributes( - ir_context, out_dim_exprs, /*out*/ output_dim_expr_attrs); - // generate symbol_bindings - std::set symbol_names_in_out_dim_exprs{}; - CollectSymbolNames(out_dim_exprs, &symbol_names_in_out_dim_exprs); - GenerateSymbolBindings(ShapeOrDataDimExprs4Value, - input_tensors, - symbol_names_in_out_dim_exprs, - /*out*/ symbol_bindings); - return true; + return MakeGenerateShapeOpAttribute( + ir_context, + ShapeOrDataDimExprs4Value, + out_dim_exprs, + origin_inputs, + minimal_inputs, + output_dim_expr_attrs, + symbol_bindings + ); } std::optional GetOutOfRewritedGenerateShapeOp( @@ -302,8 +116,9 @@ std::optional GetOutOfRewritedGenerateShapeOp( GenerateShapeOp::SymbolBindings symbol_bindings{}; bool success = MakeGenerateShapeOpAttribute(rewriter->ir_context(), ShapeOrDataDimExprs4Value, - input_tensors, shape, + /*origin inputs*/input_tensors, + /*minimal inputs*/&input_tensors, &output_dim_expr_attrs, &symbol_bindings); if (!success) return std::nullopt; From 08dd3c081ce6c1729a37b8d73c13fb552c939209 Mon Sep 17 00:00:00 2001 From: jiahy0825 Date: Tue, 2 Jan 2024 13:48:16 +0000 Subject: [PATCH 2/3] fix complier complaint --- paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc index df2b55aeb7280..54a2f988a18e4 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc @@ -598,9 +598,9 @@ std::vector GetMinimalInputs( const std::vector& input_tensors) { std::unordered_set handdled_dim_exprs; std::unordered_set first_occurred_input_tensors; - auto TryCollectFirstOcurredInput_tensor = [&](pir::Value input_tensor, const auto& dim_exprs) { + auto TryCollectFirstOcurredInput_tensor = [&](pir::Value input_tensor, const std::vector& dim_exprs) { for (const auto& dim_expr : dim_exprs) { - if (dim_expr.Has()) continue; + if (dim_expr.isa()) continue; if (!handdled_dim_exprs.insert(dim_expr).second) { first_occurred_input_tensors.insert(input_tensor); } From 04481d117344692f18e68fc62ddf59daa411d290 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Wed, 3 Jan 2024 01:55:37 +0000 Subject: [PATCH 3/3] Code format --- .../operator/ir/generate_shape_util.cc | 35 ++++++++++--------- .../dialect/operator/ir/generate_shape_util.h | 4 +-- ...e_shape_ops_into_generate_shape_op_pass.cc | 20 +++++------ 3 files changed, 30 insertions(+), 29 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc index 54a2f988a18e4..f64bb9269d63a 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc @@ -13,9 +13,9 @@ // limitations under the License. #include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h" +#include #include "paddle/pir/core/builder.h" #include "paddle/pir/core/builtin_attribute.h" -#include namespace cinn::dialect { using namespace symbol; // NOLINT @@ -598,22 +598,25 @@ std::vector GetMinimalInputs( const std::vector& input_tensors) { std::unordered_set handdled_dim_exprs; std::unordered_set first_occurred_input_tensors; - auto TryCollectFirstOcurredInput_tensor = [&](pir::Value input_tensor, const std::vector& dim_exprs) { - for (const auto& dim_expr : dim_exprs) { - if (dim_expr.isa()) continue; - if (!handdled_dim_exprs.insert(dim_expr).second) { - first_occurred_input_tensors.insert(input_tensor); - } - } - }; + auto TryCollectFirstOcurredInput_tensor = + [&](pir::Value input_tensor, + const std::vector& dim_exprs) { + for (const auto& dim_expr : dim_exprs) { + if (dim_expr.isa()) continue; + if (!handdled_dim_exprs.insert(dim_expr).second) { + first_occurred_input_tensors.insert(input_tensor); + } + } + }; for (pir::Value input_tensor : input_tensors) { - const auto& shape_or_data_dim_exprs = ShapeOrDataDimExprs4Value(input_tensor); + const auto& shape_or_data_dim_exprs = + ShapeOrDataDimExprs4Value(input_tensor); if (shape_or_data_dim_exprs.data().has_value()) { TryCollectFirstOcurredInput_tensor( - input_tensor, shape_or_data_dim_exprs.data().value()); + input_tensor, shape_or_data_dim_exprs.data().value()); } - TryCollectFirstOcurredInput_tensor( - input_tensor, shape_or_data_dim_exprs.shape()); + TryCollectFirstOcurredInput_tensor(input_tensor, + shape_or_data_dim_exprs.shape()); } std::vector ret{}; ret.reserve(input_tensors.size()); @@ -625,7 +628,7 @@ std::vector GetMinimalInputs( return ret; } -} +} // namespace bool MakeGenerateShapeOpAttribute( pir::IrContext* ir_context, @@ -635,8 +638,8 @@ bool MakeGenerateShapeOpAttribute( std::vector* minial_inputs, std::vector* output_dim_expr_attrs, GenerateShapeOp::SymbolBindings* symbol_bindings) { - *minial_inputs = GetMinimalInputs(ShapeOrDataDimExprs4Value, origin_inputs); - if (!InputDimExprsAllSupported(ShapeOrDataDimExprs4Value, *minial_inputs)) { + *minial_inputs = GetMinimalInputs(ShapeOrDataDimExprs4Value, origin_inputs); + if (!InputDimExprsAllSupported(ShapeOrDataDimExprs4Value, *minial_inputs)) { VLOG(4) << "input dim_exprs are not as simple as symbols, please make sure " "they are handled by other passes"; return false; diff --git a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h index af5d0d9bdf14b..401c240f61e86 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h +++ b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h @@ -14,12 +14,12 @@ #pragma once +#include #include +#include #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/pir/core/builder.h" #include "paddle/pir/dialect/shape/utils/dim_expr.h" -#include -#include namespace cinn::dialect { diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc index 5a43660795d61..0bff6c7daa886 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc @@ -94,15 +94,13 @@ bool MakeGenerateShapeOpAttribute( const auto& shape_or_data_dim_exprs = ShapeOrDataDimExprs4Value(output_shape); CHECK(shape_or_data_dim_exprs.data().has_value()); const auto& out_dim_exprs = shape_or_data_dim_exprs.data().value(); - return MakeGenerateShapeOpAttribute( - ir_context, - ShapeOrDataDimExprs4Value, - out_dim_exprs, - origin_inputs, - minimal_inputs, - output_dim_expr_attrs, - symbol_bindings - ); + return MakeGenerateShapeOpAttribute(ir_context, + ShapeOrDataDimExprs4Value, + out_dim_exprs, + origin_inputs, + minimal_inputs, + output_dim_expr_attrs, + symbol_bindings); } std::optional GetOutOfRewritedGenerateShapeOp( @@ -117,8 +115,8 @@ std::optional GetOutOfRewritedGenerateShapeOp( bool success = MakeGenerateShapeOpAttribute(rewriter->ir_context(), ShapeOrDataDimExprs4Value, shape, - /*origin inputs*/input_tensors, - /*minimal inputs*/&input_tensors, + /*origin inputs*/ input_tensors, + /*minimal inputs*/ &input_tensors, &output_dim_expr_attrs, &symbol_bindings); if (!success) return std::nullopt;