diff --git a/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt b/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt
index dbe7f3c40adad6..9b57f4f14edeaf 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt
+++ b/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt
@@ -38,4 +38,13 @@ if(NOT CINN_ONLY)
     cinn_op_dialect
     op_dialect_vjp)

+  cinn_cc_library(
+    split_generate_shape_into_shape_ops_pass
+    SRCS
+    split_generate_shape_into_shape_ops_pass.cc
+    DEPS
+    pir
+    cinn_op_dialect
+    op_dialect_vjp)
+
 endif()
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.cc
new file mode 100644
index 00000000000000..cec66f7c70e2e1
--- /dev/null
+++ b/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.cc
@@ -0,0 +1,368 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.h" + +#include "paddle/cinn/common/dim_expr_simplify.h" +#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" +#include "paddle/common/ddim.h" +#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/api/match_context.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/dialect/shape/utils/dim_expr.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pattern_rewrite/pattern_applicator.h" +#include "paddle/pir/pattern_rewrite/pattern_match.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +namespace cinn { +namespace dialect { +namespace ir { + +namespace { + +struct TensorDimInShape { + pir::Value value; + int axis; +}; + +struct TensorDimInData { + pir::Value value; + int axis; +}; + +using TensorDim = std::variant; + +using TensorDim4SymbolNameT = + std::function(const std::string& symbol_name)>; + +struct CachedDimExprToValueConverter { + CachedDimExprToValueConverter( + const TensorDim4SymbolNameT& TensorDim4SymbolNameVal, + pir::PatternRewriter* rewriter_val) + : TensorDim4SymbolName(TensorDim4SymbolNameVal), rewriter(rewriter_val) {} + + TensorDim4SymbolNameT TensorDim4SymbolName; + pir::PatternRewriter* rewriter; + + // TODO(): Refactor to cached version if std::hash() is + // ready. 
std::unordered_map + // symbol_names2cached_value_; + + pir::Value ConvertToValue(const symbol::DimExpr& dim_expr) { + // TODO(): cache the returned value if std::hash() is + // ready + return std::visit( + [&](const auto& impl) { return ConvertToValueImpl(impl); }, + dim_expr.variant()); + } + + pir::Value GetInputShapeByInputTensor(pir::Value input_tensor) { + auto iter = tensor2shape_.find(input_tensor); + if (iter == tensor2shape_.end()) { + pir::Value input_shape = + rewriter->Build(input_tensor).out(); + iter = tensor2shape_.emplace(input_tensor, input_shape).first; + } + return iter->second; + } + + private: + std::unordered_map + tensor2shape_; + + pir::Value ConvertToValueImpl(int64_t dim_expr) { + return rewriter + ->Build(std::vector{dim_expr}, + phi::DataType::INT64) + .out(); + } + + pir::Value ConvertToValueImpl(const std::string& symbol_name) { + const auto& tensor_dim = TensorDim4SymbolName(symbol_name); + PADDLE_ENFORCE( + tensor_dim.has_value(), + phi::errors::PreconditionNotMet( + "symbol [%s] are not bound to any input of generate_shape op", + symbol_name)); + return std::visit( + [&](const auto& impl) { return ConvertTensorDimToValue(impl); }, + tensor_dim.value()); + } + + pir::Value ConvertTensorDimToValue(const TensorDimInShape& tensor_dim) { + pir::Value input_shape = GetInputShapeByInputTensor(tensor_dim.value); + return ConvertTensorDimToValue( + TensorDimInData{.value = input_shape, .axis = tensor_dim.axis}); + } + + pir::Value ConvertTensorDimToValue(const TensorDimInData& tensor_dim) { + return rewriter + ->Build( + tensor_dim.value, + std::vector{0LL}, + std::vector{tensor_dim.axis}, + std::vector{tensor_dim.axis + 1}, + std::vector{}, + std::vector{}) + .out(); + } + + pir::Value ConvertToValueImpl( + const symbol::Negative& dim_expr) { + LOG(FATAL) << "Dead code. 
This logical should handled by " + "ConvertToValueImpl(symbol::Add)"; + } + + pir::Value ConvertToValueImpl( + const symbol::Reciprocal& dim_expr) { + LOG(FATAL) << "Dead code. This logical should handled by " + "ConvertToValueImpl(symbol::Mul)"; + } + + pir::Value ConvertToValueImpl(const symbol::Add& dim_expr) { + const auto& [operands] = dim_expr; + CHECK_GT(operands->size(), 0); + pir::Value acc = ConvertToValue(operands->at(0)); + for (int i = 1; i < operands->size(); ++i) { + if (operands->at(i).isa>()) { + const auto& [operand] = + *operands->at(i).dyn_cast>(); + pir::Value operand_value = ConvertToValue(operand); + acc = rewriter->Build(acc, operand_value) + .out(); + } else { + pir::Value operand_value = ConvertToValue(operands->at(i)); + acc = rewriter->Build(acc, operand_value).out(); + } + } + return acc; + } + + pir::Value ConvertToValueImpl(const symbol::Mul& dim_expr) { + const auto& [operands] = dim_expr; + CHECK_GT(operands->size(), 0); + pir::Value prod = ConvertToValue(operands->at(0)); + for (int i = 1; i < operands->size(); ++i) { + if (operands->at(i).isa>()) { + const auto& [operand] = + *operands->at(i).dyn_cast>(); + pir::Value operand_value = ConvertToValue(operand); + prod = rewriter->Build(prod, operand_value) + .out(); + } else { + pir::Value operand_value = ConvertToValue(operands->at(i)); + prod = rewriter->Build(prod, operand_value) + .out(); + } + } + return prod; + } + + pir::Value ConvertToValueImpl(const symbol::Max& dim_expr) { + const auto& [operands] = dim_expr; + CHECK_GT(operands->size(), 0); + pir::Value max = ConvertToValue(operands->at(0)); + for (int i = 1; i < operands->size(); ++i) { + pir::Value operand_value = ConvertToValue(operands->at(i)); + max = rewriter->Build(max, operand_value).out(); + } + return max; + } + + pir::Value ConvertToValueImpl(const symbol::Min& dim_expr) { + const auto& [operands] = dim_expr; + CHECK_GT(operands->size(), 0); + pir::Value min = ConvertToValue(operands->at(0)); + for (int i = 1; i 
< operands->size(); ++i) { + pir::Value operand_value = ConvertToValue(operands->at(i)); + min = rewriter->Build(min, operand_value).out(); + } + return min; + } + + pir::Value ConvertToValueImpl( + const symbol::Broadcast& dim_expr) { + const auto& [operands] = dim_expr; + CHECK_GT(operands->size(), 0); + pir::Value broadcasted = ConvertToValue(operands->at(0)); + for (int i = 1; i < operands->size(); ++i) { + pir::Value operand_value = ConvertToValue(operands->at(i)); + broadcasted = rewriter + ->Build( + broadcasted, operand_value) + .out(); + } + return broadcasted; + } +}; + +} // namespace + +class SplitGenerateShapeIntoShapeOps + : public pir::OpRewritePattern { + public: + using pir::OpRewritePattern::OpRewritePattern; + + bool MatchAndRewrite(cinn::dialect::GenerateShapeOp op, + pir::PatternRewriter& rewriter) const override { + std::optional out_replacement = + GetOutReplacement(op, &rewriter); + if (!out_replacement.has_value()) return false; + rewriter.ReplaceAllUsesWith(op->result(0), out_replacement.value()); + return true; + } + + std::optional GetOutReplacement( + cinn::dialect::GenerateShapeOp op, pir::PatternRewriter* rewriter) const { + std::vector dim_exprs = GetOutDimExprs(op); + TensorDim4SymbolNameT TensorDim4SymbolName = + MakeGetterTensorDim4SymbolName(op); + if (!TensorDim4SymbolName) return std::nullopt; + CachedDimExprToValueConverter converter{TensorDim4SymbolName, rewriter}; + return GetValueOfRewritedOps(dim_exprs, &converter); + } + + TensorDim4SymbolNameT MakeGetterTensorDim4SymbolName( + cinn::dialect::GenerateShapeOp op) const { + std::unordered_map symbol_name2tenso_dim{}; + const auto& attr_map = op->attributes(); + const auto& iter = attr_map.find("symbol_bindings"); + PADDLE_ENFORCE((iter != attr_map.end()), + phi::errors::PreconditionNotMet( + "attr symbol_bindings MUST in attribute map for [%s] op", + op->name())); + pir::Attribute attr = iter->second; + auto* Convert = + 
&cinn::dialect::GenerateShapeOp::ConvertAttributeToSymbolBindings; + const auto& symbol_bindings = Convert(attr); + PADDLE_ENFORCE( + symbol_bindings.has_value(), + phi::errors::PreconditionNotMet("attr symbol_bindings in op [%s] can " + "not be converted to symbol bindings", + op->name())); + for (const auto& symbol_binding : symbol_bindings.value()) { + InsertSymbolBinding(op, symbol_binding, &symbol_name2tenso_dim); + } + return [map = std::move(symbol_name2tenso_dim)]( + const std::string& symbol_name) -> std::optional { + auto iter = map.find(symbol_name); + if (iter == map.end()) return std::nullopt; + return iter->second; + }; + } + + void InsertSymbolBinding( + cinn::dialect::GenerateShapeOp op, + const cinn::dialect::GenerateShapeOp::SymbolBinding& symbol_binding, + std::unordered_map* symbol_name2tenso_dim) const { + return std::visit( + [&](const auto& impl) { + return InsertSymbolBindingImpl(op, impl, symbol_name2tenso_dim); + }, + symbol_binding); + } + + void InsertSymbolBindingImpl( + cinn::dialect::GenerateShapeOp op, + const cinn::dialect::GenerateShapeOp::DataSymbolBinding& symbol_binding, + std::unordered_map* symbol_name2tenso_dim) const { + (*symbol_name2tenso_dim)[symbol_binding.symbol_name] = TensorDimInData{ + .value = op.operand_source(symbol_binding.input_tensor_idx), + .axis = symbol_binding.input_tensor_dim_idx}; + } + + void InsertSymbolBindingImpl( + cinn::dialect::GenerateShapeOp op, + const cinn::dialect::GenerateShapeOp::ShapeSymbolBinding& symbol_binding, + std::unordered_map* symbol_name2tenso_dim) const { + (*symbol_name2tenso_dim)[symbol_binding.symbol_name] = TensorDimInShape{ + .value = op.operand_source(symbol_binding.input_tensor_idx), + .axis = symbol_binding.input_tensor_dim_idx}; + } + + std::vector GetOutDimExprs( + cinn::dialect::GenerateShapeOp op) const { + const auto& attr_map = op->attributes(); + const auto& iter = attr_map.find("output_dim_exprs"); + PADDLE_ENFORCE( + (iter != attr_map.end()), + 
phi::errors::PreconditionNotMet( + "attr output_dim_exprs MUST in attribute map for [%s] op", + op->name())); + pir::Attribute output_dim_exprs_attr = iter->second; + PADDLE_ENFORCE( + output_dim_exprs_attr.isa(), + phi::errors::PreconditionNotMet( + "attr output_dim_exprs for [%s] op must be an pir::ArrayAttribute", + op->name())); + std::vector ret{}; + const auto& output_dim_exprs = + output_dim_exprs_attr.dyn_cast(); + for (int i = 0; i < output_dim_exprs.size(); ++i) { + const auto& attr = output_dim_exprs.at(i); + const auto& opt_dim_expr = cinn::dialect::ConvertAttributeToDimExpr(attr); + CHECK(opt_dim_expr.has_value()); + ret.emplace_back(opt_dim_expr.value()); + } + return ret; + } + + pir::Value GetValueOfRewritedOps( + const std::vector& dim_exprs, + CachedDimExprToValueConverter* converter) const { + const std::vector& values_from_dim_exprs = + GetValuesOfRewritedOps(dim_exprs, converter); + return converter->rewriter->Build(values_from_dim_exprs) + .out(); + } + + std::vector GetValuesOfRewritedOps( + const std::vector& dim_exprs, + CachedDimExprToValueConverter* converter) const { + std::vector ret; + for (const auto& dim_expr : dim_exprs) { + const auto& simplified = cinn::common::SimplifyDimExpr(dim_expr); + pir::Value value = converter->ConvertToValue(simplified); + ret.push_back(value); + } + return ret; + } +}; + +SplitGenerateShapeIntoShapeOpsPass::SplitGenerateShapeIntoShapeOpsPass() + : pir::PatternRewritePass("split_generate_shape_into_shape_ops_pass", 1) {} + +pir::RewritePatternSet SplitGenerateShapeIntoShapeOpsPass::InitializePatterns( + pir::IrContext* context) { + pir::RewritePatternSet ps(context); + // elementwise ops + ps.Add(context); + return ps; +} + +bool SplitGenerateShapeIntoShapeOpsPass::CanApplyOn(pir::Operation* op) const { + return op->isa() && op->num_regions() > 0; +} + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git 
a/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.h b/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.h
new file mode 100644
index 00000000000000..fe0c3f124cd385
--- /dev/null
+++ b/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/pir/pass/pass.h"
+#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h"
+
+namespace cinn {
+namespace dialect {
+namespace ir {
+
// A pattern-rewrite pass: its patterns are supplied by InitializePatterns and
// it is gated by CanApplyOn (both defined in the corresponding .cc file).
+class SplitGenerateShapeIntoShapeOpsPass : public pir::PatternRewritePass {
+ public:
+  SplitGenerateShapeIntoShapeOpsPass();
+
+  pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override;
+
+  bool CanApplyOn(pir::Operation *op) const override;
+};
+
+} // namespace ir
+} // namespace dialect
+} // namespace cinn
diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt
index 52eada6c5482ff..08d8a22699e264 100755
--- a/paddle/fluid/pybind/CMakeLists.txt
+++ b/paddle/fluid/pybind/CMakeLists.txt
@@ -51,8 +51,13 @@ set(PYBIND_DEPS

 if(WITH_CINN)
   set(PYBIND_DEPS
-    ${PYBIND_DEPS} pir_transforms op_with_group_merge_pass
-    add_broadcast_to_elementwise_pass pd_to_cinn_pass sub_graph_checker)
+    ${PYBIND_DEPS}
+    pir_transforms
+    op_with_group_merge_pass
+    add_broadcast_to_elementwise_pass
+    pd_to_cinn_pass
+    sub_graph_checker
+    split_generate_shape_into_shape_ops_pass)
 endif()

 if(WITH_PSCORE)