diff --git a/paddle/pir/core/type.cc b/paddle/pir/core/type.cc
index fef0eb9c1a443..91933019fb835 100644
--- a/paddle/pir/core/type.cc
+++ b/paddle/pir/core/type.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/pir/core/type.h"
+#include "paddle/pir/core/builtin_type.h"
 #include "paddle/pir/core/dialect.h"
 #include "paddle/pir/core/type_base.h"
 
@@ -24,4 +25,10 @@ TypeId Type::type_id() { return storage_->abstract_type().type_id(); }
 const AbstractType &Type::abstract_type() { return storage_->abstract_type(); }
 
 Dialect &Type::dialect() const { return storage_->abstract_type().dialect(); }
+
+bool Type::IsIntOrIndex() const {
+  return isa<IndexType>() || isa<Int8Type>() || isa<UInt8Type>() ||
+         isa<Int16Type>() || isa<Int32Type>() || isa<Int64Type>();
+}
+
 }  // namespace pir
diff --git a/paddle/pir/core/type.h b/paddle/pir/core/type.h
index 0c2cb9d6bc7fa..ba633343749c0 100644
--- a/paddle/pir/core/type.h
+++ b/paddle/pir/core/type.h
@@ -39,7 +39,7 @@ class IR_API Type {
   using TypeBase = detail::StorageHelperBase;
 
   using Storage = TypeStorage;
 
@@ -116,6 +116,12 @@ class IR_API Type {
     return pir::cast(*this);
   }
 
+  ///
+  /// \brief Return true if this is an integer (any signedness) or an index
+  /// type.
+  ///
+  bool IsIntOrIndex() const;
+
  protected:
   const Storage *storage_{nullptr};
diff --git a/paddle/pir/dialect/shape/transforms/shape_optimization_pass.h b/paddle/pir/dialect/shape/transforms/passes.h
similarity index 95%
rename from paddle/pir/dialect/shape/transforms/shape_optimization_pass.h
rename to paddle/pir/dialect/shape/transforms/passes.h
index 43bad532c920d..9433ef9b570bd 100644
--- a/paddle/pir/dialect/shape/transforms/shape_optimization_pass.h
+++ b/paddle/pir/dialect/shape/transforms/passes.h
@@ -21,6 +21,7 @@
 namespace pir {
 class Pass;
 
+// Apply some shape-related optimization.
 IR_API std::unique_ptr<Pass> CreateShapeOptimizationPass();
 
 }  // namespace pir
diff --git a/paddle/pir/dialect/shape/transforms/shape_optimization.cc b/paddle/pir/dialect/shape/transforms/shape_optimization.cc
index 959d098675b29..767353efdbc5f 100644
--- a/paddle/pir/dialect/shape/transforms/shape_optimization.cc
+++ b/paddle/pir/dialect/shape/transforms/shape_optimization.cc
@@ -12,10 +12,127 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/pir/dialect/shape/transforms/shape_optimization.h"
+#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
+#include "paddle/pir/dialect/shape/ir/shape_op.h"
+
+#include "paddle/pir/core/builtin_op.h"
+#include "paddle/pir/core/program.h"
 #include "paddle/pir/dialect/shape/utils/shape_utils.h"
+#include "paddle/pir/pass/pass.h"
+#include "paddle/pir/pass/pass_manager.h"
+#include "paddle/pir/pass/pass_registry.h"
 
 namespace pir {
+namespace {
+using PassPipelineRunner =
+    std::function<bool(pir::PassManager&, pir::ModuleOp)>;
+
+bool InsertTieShapeOnValue(pir::Value value,
+                           pir::Builder& builder) {  // NOLINT
+  auto ty = value.type().dyn_cast<paddle::dialect::DenseTensorType>();
+
+  if (!ty || ty.dims().size() == 0) return true;
+  std::vector<pir::Value> dimSizes;
+  for (int64_t dim = 0, rank = ty.dims().size(); dim < rank; ++dim) {
+    auto dimOp = builder.Build<pir::dialect::TensorDimOp>(value, dim);
+    dimSizes.push_back(dimOp.out());
+  }
+  builder.Build<pir::dialect::TieShapeOp>(value, dimSizes);
+  return true;
+}
+
+bool InsertTieShapeOnRegion(pir::Region* region);
+
+bool InsertTieShapeOnOperation(pir::Operation* op,
+                               pir::Builder& builder) {  // NOLINT
+  // TODO(zhangbo63): skip more specialized Ops.
+  if (op->isa() || op->isa())
+    return true;
+
+  for (size_t i = 0; i < op->num_regions(); ++i) {
+    if (!InsertTieShapeOnRegion(&(op->region(i)))) return false;
+  }
+  builder.SetInsertionPointAfter(op);
+  for (pir::OpResult v : op->results()) {
+    if (!InsertTieShapeOnValue(v, builder)) return false;
+  }
+
+  return true;
+}
+
+bool InsertTieShapeOnBlock(pir::Block* block) {
+  pir::Builder builder =
+      pir::Builder(pir::IrContext::Instance(), block, block->begin());
+  // TODO(liujinnan): mapping block arguments
+
+  std::vector<pir::Operation*> op_list;
+  for (pir::Operation* op : *block) op_list.push_back(op);
+  for (pir::Operation* op : op_list) {
+    if (!InsertTieShapeOnOperation(op, builder)) return false;
+  }
+  return true;
+}
+
+bool InsertTieShapeOnRegion(pir::Region* region) {
+  for (pir::Block* block : *region) {
+    if (!InsertTieShapeOnBlock(block)) return false;
+  }
+  return true;
+}
+
+bool MaterializeShapeComputation(pir::ModuleOp m) {
+  if (!InsertTieShapeOnRegion(&(m->region(0)))) return false;
+  // TODO(liujinnan): add rewriter pattern for reifyInferShape.
+  return true;
+}
+
+// Returns true if the type is possibly a shape tensor type.
+// Shape tensor type:
+//   - rank-1 static-shaped tensor type
+//   - element type of the tensor is int or index
+//   - number of elements of the tensor < 32, supposing that the
+//     highest possible rank is smaller than 32.
+bool IsCandidateShapeTensorType(Type type) {
+  auto tensor_type = type.dyn_cast<paddle::dialect::DenseTensorType>();
+  auto shaped_type = tensor_type.dyn_cast<ShapedTypeInterface>();
+
+  return (tensor_type && shaped_type && shaped_type.GetRank() == 1 &&
+          shaped_type.HasStaticShape() &&
+          shaped_type.GetElementType().IsIntOrIndex() &&
+          shaped_type.GetShape()[0] < 32);
+}
+
+class ShapeComputationIRAnalysis {
+ public:
+  using func = std::function<bool(Operation* op)>;
+  explicit ShapeComputationIRAnalysis(ModuleOp m,
+                                      SymbolicDimMgr& mgr);  // NOLINT
+  bool Run();
+
+ private:
+  bool RunOnRegion(Region* region, func fn);
+  bool RunOnBlock(Block* block, func fn);
+  bool RunOnOperation(Operation* op, func fn);
+
+  bool BuildShapeOnOperation(Operation* op);
+  bool BuildShapeOnValue(Value value);
+
+  bool ApplyOpConstraint(Operation* op);
+  bool ApplyIndexOpConstraint(Operation* op);
+  bool ApplyTieShapeOpConstraint(Operation* op);
+
+  bool initialized_ = false;
+  ModuleOp m_;
+  SymbolicDimMgr& mgr_;
+
+  std::unordered_map<Value, SymbolicDim> value_to_sym_dim_;
+
+  // shape tensor is the 1D ranked tensor with int/index dtype.
+  std::unordered_map<Value, std::vector<SymbolicDim>>
+      shape_tensor_to_sym_dims_;
+
+  std::unordered_map<Value, std::vector<SymbolicDim>>
+      dense_tensor_to_sym_dims_;
+};
+
 ShapeComputationIRAnalysis::ShapeComputationIRAnalysis(ModuleOp m,
                                                        SymbolicDimMgr& mgr)
@@ -25,16 +142,16 @@ bool ShapeComputationIRAnalysis::Run() {
   // Make sure only run once.
if (initialized_) return false; initialized_ = true; - auto buildShapeFunc = + auto build_shape_func = std::bind(&ShapeComputationIRAnalysis::BuildShapeOnOperation, this, std::placeholders::_1); - if (!RunOnRegion(&(m_->region(0)), buildShapeFunc)) return false; - auto applyOpConstraintFunc = + if (!RunOnRegion(&(m_->region(0)), build_shape_func)) return false; + auto apply_op_constraint_func = std::bind(&ShapeComputationIRAnalysis::ApplyOpConstraint, this, std::placeholders::_1); - if (!RunOnRegion(&(m_->region(0)), applyOpConstraintFunc)) return false; + if (!RunOnRegion(&(m_->region(0)), apply_op_constraint_func)) return false; return true; } @@ -90,7 +207,7 @@ bool ShapeComputationIRAnalysis::BuildShapeOnOperation(Operation* op) { op->set_attribute(SymbolicDim::GetSymbolicDimAttrName(), ArrayAttribute::get(m_->ir_context(), attrs)); } - rankedTensor2SymDims_[value] = std::move(symbols); + dense_tensor_to_sym_dims_[value] = std::move(symbols); return true; } for (size_t i = 0; i < op->num_results(); ++i) { @@ -101,15 +218,15 @@ bool ShapeComputationIRAnalysis::BuildShapeOnOperation(Operation* op) { bool ShapeComputationIRAnalysis::BuildShapeOnValue(Value value) { Type type = value.type(); - if (IsIntOrIndex(type)) { + if (type.IsIntOrIndex()) { SymbolicDim sym = mgr_.NewSymbolicDim(); - value2SymDim_[value] = sym; + value_to_sym_dim_[value] = sym; } else if (IsCandidateShapeTensorType(type)) { - auto shapedTy = type.dyn_cast(); + auto shaped_type = type.dyn_cast(); std::vector symbols; - for (size_t i = 0, d = shapedTy.GetShape()[0]; i < d; ++i) + for (size_t i = 0, d = shaped_type.GetShape()[0]; i < d; ++i) symbols.push_back(mgr_.NewSymbolicDim()); - shapeTensor2SymDims_[value] = std::move(symbols); + shape_tensor_to_sym_dims_[value] = std::move(symbols); } return true; } @@ -128,24 +245,24 @@ bool ShapeComputationIRAnalysis::ApplyIndexOpConstraint(Operation* op) { if (op->num_results() == 0) return true; Type type = op->result(0).type(); - if (!IsIntOrIndex(type)) return true; - - if (auto dimOp = op->dyn_cast()) { - int64_t dimIndex = dimOp.index() - .dyn_cast() - .owner() - ->attribute("value") - .data(); - value2SymDim_[dimOp.out()].UpdateKnownNonNegative(true); + if (!type.IsIntOrIndex()) return true; + + if (auto dim_op = op->dyn_cast()) { + int64_t dim_index = dim_op.index() + .dyn_cast() + .owner() + ->attribute("value") + .data(); + value_to_sym_dim_[dim_op.out()].UpdateKnownNonNegative(true); if (!mgr_.MapSymbolicDimEqual( - value2SymDim_[dimOp.out()], - rankedTensor2SymDims_[dimOp.source()][dimIndex])) { + value_to_sym_dim_[dim_op.out()], + dense_tensor_to_sym_dims_[dim_op.source()][dim_index])) { return false; } - } else if (auto constOp = op->dyn_cast()) { - int64_t val = constOp.value().dyn_cast().data(); - if (!mgr_.MapSymbolicDimEqual(value2SymDim_[op->result(0)], + } else if (auto const_op = op->dyn_cast()) { + int64_t val = const_op.value().dyn_cast().data(); + if (!mgr_.MapSymbolicDimEqual(value_to_sym_dim_[op->result(0)], mgr_.NewConstantSymbolicDim(val))) { return false; } @@ -155,10 +272,10 @@ bool ShapeComputationIRAnalysis::ApplyIndexOpConstraint(Operation* op) { } bool ShapeComputationIRAnalysis::ApplyTieShapeOpConstraint(Operation* op) { - if (auto tieShape = op->dyn_cast()) { - auto& value = rankedTensor2SymDims_[op->operand_source(0)]; - for (size_t idx = 0; idx < tieShape.dims().size(); ++idx) { - if (!mgr_.MapSymbolicDimEqual(value2SymDim_[tieShape.dims()[idx]], + if (auto tie_shape = op->dyn_cast()) { + auto& value = 
dense_tensor_to_sym_dims_[op->operand_source(0)]; + for (size_t idx = 0; idx < tie_shape.dims().size(); ++idx) { + if (!mgr_.MapSymbolicDimEqual(value_to_sym_dim_[tie_shape.dims()[idx]], value[idx])) return false; mgr_.GetRootSymbolicDim(value[idx]).UpdateKnownNonNegative(true); @@ -166,4 +283,49 @@ bool ShapeComputationIRAnalysis::ApplyTieShapeOpConstraint(Operation* op) { } return true; } + +bool OptimizeShapeComputation(pir::ModuleOp m, PassPipelineRunner runner) { + // TODO(liujinnan): Do some Canonicalizer. + pir::SymbolicDimMgr mgr(m); + IR_ENFORCE(mgr.Load(), + "SymbolicDimMgr Load failed in OptimizeShapeComputation."); + ShapeComputationIRAnalysis analysis(m, mgr); + if (!analysis.Run()) { + return false; + } + IR_ENFORCE(mgr.Save(), + "SymbolicDimMgr save failed in OptimizeShapeComputation."); + return true; +} + +class ShapeOptimizationPass : public pir::Pass { + public: + ShapeOptimizationPass() : pir::Pass("shape_optimization", 0) {} + + void Run(pir::Operation* op) override { + auto module_op = op->dyn_cast(); + IR_ENFORCE(module_op, "ShapeOptimizationPass should run on module op."); + MaterializeShapeComputation(module_op); + // runner is for Canonicalizer. + PassPipelineRunner runner = [this](pir::PassManager& pm, pir::ModuleOp m) { + return pm.Run(m.program()); + }; + if (!OptimizeShapeComputation(module_op, runner)) { + return; + } + } + + bool CanApplyOn(pir::Operation* op) const override { + return op->isa() && op->num_regions() > 0; + } +}; + +} // namespace + +std::unique_ptr CreateShapeOptimizationPass() { + return std::make_unique(); +} + } // namespace pir + +REGISTER_IR_PASS(shape_optimization, pir::ShapeOptimizationPass); diff --git a/paddle/pir/dialect/shape/transforms/shape_optimization.h b/paddle/pir/dialect/shape/transforms/shape_optimization.h deleted file mode 100644 index ba711f288a770..0000000000000 --- a/paddle/pir/dialect/shape/transforms/shape_optimization.h +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/pir/dialect/shape/utils/shape_optimization_utils.h" -#include "paddle/pir/dialect/shape/utils/symbol_table.h" - -namespace pir { -class ShapeComputationIRAnalysis { - public: - using func = std::function; - explicit ShapeComputationIRAnalysis(ModuleOp m, - SymbolicDimMgr& mgr); // NOLINT - bool Run(); - - private: - bool RunOnRegion(Region* region, func fn); - bool RunOnBlock(Block* block, func fn); - bool RunOnOperation(Operation* op, func fn); - - bool BuildShapeOnOperation(Operation* op); - bool BuildShapeOnValue(Value value); - - bool ApplyOpConstraint(Operation* op); - bool ApplyIndexOpConstraint(Operation* op); - bool ApplyTieShapeOpConstraint(Operation* op); - - bool initialized_ = false; - ModuleOp m_; - SymbolicDimMgr& mgr_; - - std::unordered_map value2SymDim_; - - // shape tensor is the 1D ranked tensor with int/index dtype. 
- std::unordered_map> shapeTensor2SymDims_; - - std::unordered_map> rankedTensor2SymDims_; -}; - -} // namespace pir diff --git a/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc b/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc deleted file mode 100644 index f9316f3682aa3..0000000000000 --- a/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/pir/dialect/shape/transforms/shape_optimization_pass.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" -#include "paddle/pir/dialect/shape/ir/shape_op.h" - -#include "paddle/pir/core/builtin_op.h" -#include "paddle/pir/core/program.h" -#include "paddle/pir/dialect/shape/transforms/shape_optimization.h" -#include "paddle/pir/dialect/shape/utils/shape_utils.h" -#include "paddle/pir/pass/pass.h" -#include "paddle/pir/pass/pass_manager.h" -#include "paddle/pir/pass/pass_registry.h" - -namespace { -using PassPipelineRunner = - std::function; - -bool InsertTieShapeOnValue(pir::Value value, - pir::Builder& builder) { // NOLINT - auto ty = value.type().dyn_cast(); - - if (!ty || ty.dims().size() == 0) return true; - std::vector dimSizes; - for (int64_t dim = 0, rank = ty.dims().size(); dim < rank; ++dim) { - auto dimOp = builder.Build(value, dim); - dimSizes.push_back(dimOp.out()); - } - builder.Build(value, dimSizes); - return true; -} - -bool InsertTieShapeOnRegion(pir::Region* region); - -bool InsertTieShapeOnOperation(pir::Operation* op, - pir::Builder& builder) { // NOLINT - // TODO(zhangbo63): skip more specialized Ops. - if (op->isa() || op->isa()) - return true; - - for (size_t i = 0; i < op->num_regions(); ++i) { - if (!InsertTieShapeOnRegion(&(op->region(i)))) return false; - } - builder.SetInsertionPointAfter(op); - for (pir::OpResult v : op->results()) { - if (!InsertTieShapeOnValue(v, builder)) return false; - } - - return true; -} - -bool InsertTieShapeOnBlock(pir::Block* block) { - pir::Builder builder = - pir::Builder(pir::IrContext::Instance(), block, block->begin()); - // TODO(liujinnan): mapping block arguments - - std::vector op_list; - for (pir::Operation* op : *block) op_list.push_back(op); - for (pir::Operation* op : op_list) { - if (!InsertTieShapeOnOperation(op, builder)) return false; - } - return true; -} - -bool InsertTieShapeOnRegion(pir::Region* region) { - for (pir::Block* block : *region) { - if (!InsertTieShapeOnBlock(block)) return false; - } - return true; -} - -bool MaterializeShapeComputation(pir::ModuleOp m) { - if (!InsertTieShapeOnRegion(&(m->region(0)))) return false; - // TODO(liujinnan): add rewitter pattern for reifyInferShape. - return true; -} - -bool OptimizeShapeComputation(pir::ModuleOp m, PassPipelineRunner runner) { - // TODO(liujinnan): Do some Canonicalizer. 
- pir::SymbolicDimMgr mgr(m); - IR_ENFORCE(mgr.Load(), - "SymbolicDimMgr Load failed in OptimizeShapeComputation."); - pir::ShapeComputationIRAnalysis analysis(m, mgr); - if (!analysis.Run()) { - return false; - } - IR_ENFORCE(mgr.Save(), - "SymbolicDimMgr save failed in OptimizeShapeComputation."); - return true; -} - -class ShapeOptimizationPass : public pir::Pass { - public: - ShapeOptimizationPass() : pir::Pass("shape_optimization", 0) {} - - void Run(pir::Operation* op) override { - auto module_op = op->dyn_cast(); - IR_ENFORCE(module_op, "ShapeOptimizationPass should run on module op."); - MaterializeShapeComputation(module_op); - // runner is for Canonicalizer. - PassPipelineRunner runner = [this](pir::PassManager& pm, pir::ModuleOp m) { - return pm.Run(m.program()); - }; - if (!OptimizeShapeComputation(module_op, runner)) { - return; - } - } - - bool CanApplyOn(pir::Operation* op) const override { - return op->isa() && op->num_regions() > 0; - } -}; - -} // namespace - -namespace pir { - -std::unique_ptr CreateShapeOptimizationPass() { - return std::make_unique(); -} - -} // namespace pir - -REGISTER_IR_PASS(shape_optimization, ShapeOptimizationPass); diff --git a/paddle/pir/dialect/shape/utils/shape_optimization_utils.h b/paddle/pir/dialect/shape/utils/shape_optimization_utils.h index fdec957aa6be7..5541e8a8ee2f1 100644 --- a/paddle/pir/dialect/shape/utils/shape_optimization_utils.h +++ b/paddle/pir/dialect/shape/utils/shape_optimization_utils.h @@ -19,21 +19,29 @@ namespace pir { using dialect::SymbolicDim; +// Represents a product of symbolic and concrete factors. +// Used to prove product equalities symbolically. struct SymbolicDimProduct { + // List all symbolic factors that can not be aggregated. std::vector symbols; + + // Product of all const factors. int64_t factor = 1; bool empty() { return factor == 1 && symbols.empty(); } - friend inline bool operator==(const SymbolicDimProduct& lhs, - const SymbolicDimProduct& rhs) { - return lhs.factor == rhs.factor && lhs.symbols == rhs.symbols; - } - - friend inline bool operator!=(const SymbolicDimProduct& lhs, - const SymbolicDimProduct& rhs) { - return !(lhs == rhs); - } }; +// Returns true if two SymbolicDimProduct are equal +inline bool operator==(const SymbolicDimProduct& lhs, + const SymbolicDimProduct& rhs) { + return lhs.factor == rhs.factor && lhs.symbols == rhs.symbols; +} + +// Returns true if two SymbolicDimProduct are not equal +inline bool operator!=(const SymbolicDimProduct& lhs, + const SymbolicDimProduct& rhs) { + return !(lhs == rhs); +} + struct SymDimHasher { size_t operator()(const dialect::SymbolicDim& symbol) const noexcept { return std::hash{}(symbol.operation()); @@ -51,28 +59,67 @@ struct SymProductHasher { } }; +// A class to manage shape-constraint related IR class SymbolicDimMgr { public: explicit SymbolicDimMgr(ModuleOp m); + + // Loads pre-defined SymbolicDim ops from the module this mgr runs on. bool Load(); + + // Create a new symbolicDim instance owned by this mgr. SymbolicDim NewSymbolicDim(const std::string& name = {}); + + // Create a symbolicDim with static dim size == `val`. SymbolicDim NewConstantSymbolicDim(int64_t val); + + // Create a symbolicDim with given value. std::vector CreateSymbolicDimsForRankedValue(Value value); + + // All symbolic-equal dims form a group. + // Returns the root SymbolicDim of the symbolic-equal symbolic dim group which + // this SymbolicDim belongs to. SymbolicDim GetRootSymbolicDim(SymbolicDim symbol); + + // Returns true if lhs and rhs are known to be equal. 
  bool IsSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs);
+
+  // Marks lhs and rhs as having the same size and tries to merge the static
+  // info known for lhs & rhs. Returns false if the merge fails.
   bool MapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs);
+
+  // Returns the simplified version of SymbolicDimProduct.
+  // This will try to fold some symbolicDim ops with const values.
   SymbolicDimProduct SimplifySymbolicDimProduct(const SymbolicDimProduct& x);
+
+  // Returns the simplified version of the SymbolicDimProduct pair.
+  // This will try to reduce some common symbolic ops if they are known
+  // nonzero.
   std::pair<SymbolicDimProduct, SymbolicDimProduct>
   SimplifySymbolicDimProductPair(const SymbolicDimProduct& x,
                                  const SymbolicDimProduct& y);
+
+  // Returns null if x is not exactly divisible by y; otherwise returns x / y.
+  // All symbols are supposed to be nonzero, so common symbolic dim factors
+  // can be eliminated safely. For example:
+  //   x = 6 * symbol_0 * symbol_1 * symbol_2
+  //   y = 3 * symbol_0 * symbol_1
+  //   x / y == 2 * symbol_2 (all symbols are nonzero)
   SymbolicDimProduct* SymbolicDimProductDivide(const SymbolicDimProduct& x,
                                                const SymbolicDimProduct& y);
-  bool Save();
+
+  // Returns true if the products of group [a0, b0, ...] and group
+  // [a1, b1, c1, ...] are known to be equal:
+  //   `a0 * b0 * ... == a1 * b1 * c1 * ...`
   bool IsSymbolicDimProductEqual(const SymbolicDimProduct& lhs,
                                  const SymbolicDimProduct& rhs);
 
+  // Marks `product([a0, b0, ...]) == product([a1, b1, c1, ...])`.
   bool MapSymbolicDimProductEqual(const SymbolicDimProduct& lhs,
                                   const SymbolicDimProduct& rhs);
+
+  // Saves the updated shape constraint IR.
+  bool Save();
+
+  // Returns the SymbolTable.
   SymbolTable& symbolTable() { return symbol_table_; }
 
  private:
diff --git a/paddle/pir/dialect/shape/utils/shape_utils.cc b/paddle/pir/dialect/shape/utils/shape_utils.cc
index 4e4c87ed30f86..2b130f73f6d07 100644
--- a/paddle/pir/dialect/shape/utils/shape_utils.cc
+++ b/paddle/pir/dialect/shape/utils/shape_utils.cc
@@ -129,20 +129,4 @@ bool ShapeConstraintIRAnalysis::IsProductEqual(Value lhs,
   return mgr_.IsSymbolicDimProductEqual(lhs_prod, rhs_prod);
 }
 
-bool IsIntOrIndex(Type type) {
-  return type.isa() || type.isa() ||
-         type.isa() || type.isa() ||
-         type.isa() || type.isa();
-}
-
-bool IsCandidateShapeTensorType(Type type) {
-  if (auto tensorTy = type.dyn_cast()) {
-    auto shapedTy = tensorTy.dyn_cast();
-    return (shapedTy.GetRank() == 1 && shapedTy.HasStaticShape() &&
-            IsIntOrIndex(shapedTy.GetElementType()) &&
-            shapedTy.GetShape()[0] < 32);
-  }
-  return false;
-}
-
 }  // namespace pir
diff --git a/paddle/pir/dialect/shape/utils/shape_utils.h b/paddle/pir/dialect/shape/utils/shape_utils.h
index 72510f8a23c83..0842313962d36 100644
--- a/paddle/pir/dialect/shape/utils/shape_utils.h
+++ b/paddle/pir/dialect/shape/utils/shape_utils.h
@@ -48,8 +48,6 @@ class ShapeAnalysis {
   virtual bool IsSameNumElements(Value lhs, Value rhs);
 };
 
-using dialect::SymbolicDim;
-
 // A subclass to impement `ShapeAnalysis` on buffer level.
 // The implementation is based on shape constraint ir.
 class ShapeConstraintIRAnalysis : public ShapeAnalysis {
@@ -78,9 +76,8 @@ class ShapeConstraintIRAnalysis : public ShapeAnalysis {
   SymbolicDimMgr mgr_;
   // Map a ranked memref value to an array of symbolicDims, each represents one
   // dimension size of the memref value.
- std::unordered_map> value_to_sym_dims_; + std::unordered_map> + value_to_sym_dims_; }; -bool IsIntOrIndex(Type type); -bool IsCandidateShapeTensorType(Type ty); } // namespace pir diff --git a/test/cpp/pir/shape_dialect/CMakeLists.txt b/test/cpp/pir/shape_dialect/CMakeLists.txt index d5fe787de4a80..349d6a32dfa22 100644 --- a/test/cpp/pir/shape_dialect/CMakeLists.txt +++ b/test/cpp/pir/shape_dialect/CMakeLists.txt @@ -1,13 +1,22 @@ -cc_test_old( - symbolic_op_test +paddle_test( + shape_op_test SRCS - symbolic_op_test.cc + shape_op_test.cc DEPS pd_op_dialect pir gtest) -cc_test_old( +paddle_test( + shape_struct_test + SRCS + shape_struct_test.cc + DEPS + pd_op_dialect + pir + gtest) + +paddle_test( constraint_pass_test SRCS constraint_pass_test.cc @@ -19,3 +28,9 @@ cc_test_old( set_tests_properties( constraint_pass_test PROPERTIES ENVIRONMENT "FLAGS_enable_new_ir_in_executor=true") + +if(WITH_ONNXRUNTIME AND WIN32) + # Copy onnxruntime for some c++ test in Windows, since the test will + # be build only in CI, so suppose the generator in Windows is Ninja. + copy_onnx(shape_op_test) +endif() diff --git a/test/cpp/pir/shape_dialect/constraint_pass_test.cc b/test/cpp/pir/shape_dialect/constraint_pass_test.cc index f5282727f7250..860bf34a69ac4 100644 --- a/test/cpp/pir/shape_dialect/constraint_pass_test.cc +++ b/test/cpp/pir/shape_dialect/constraint_pass_test.cc @@ -39,8 +39,7 @@ #include "paddle/pir/core/value.h" #include "paddle/pir/dialect/shape/ir/shape_dialect.h" #include "paddle/pir/dialect/shape/ir/shape_op.h" -#include "paddle/pir/dialect/shape/transforms/shape_optimization.h" -#include "paddle/pir/dialect/shape/transforms/shape_optimization_pass.h" +#include "paddle/pir/dialect/shape/transforms/passes.h" #include "paddle/pir/dialect/shape/utils/shape_utils.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_manager.h" @@ -134,8 +133,5 @@ TEST(constraint_pass, shape_computation_run) { EXPECT_TRUE(pm.Run(&program)); pir::SymbolicDimMgr mgr(program.module_op()); EXPECT_TRUE(mgr.Load()); - pir::ShapeComputationIRAnalysis analysis(program.module_op(), mgr); - EXPECT_TRUE(analysis.Run()); - EXPECT_FALSE(analysis.Run()); EXPECT_TRUE(mgr.Save()); } diff --git a/test/cpp/pir/shape_dialect/shape_op_test.cc b/test/cpp/pir/shape_dialect/shape_op_test.cc new file mode 100644 index 0000000000000..9d71e721fe72d --- /dev/null +++ b/test/cpp/pir/shape_dialect/shape_op_test.cc @@ -0,0 +1,201 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/pir/dialect/shape/ir/shape_op.h" +#include +#include +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/pir/core/block.h" +#include "paddle/pir/core/builder.h" +#include "paddle/pir/core/builtin_type.h" +#include "paddle/pir/core/builtin_type_interfaces.h" +#include "paddle/pir/core/dialect.h" +#include "paddle/pir/core/ir_context.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/dialect/shape/ir/shape_dialect.h" +#include "paddle/pir/dialect/shape/utils/shape_utils.h" +#include "paddle/pir/dialect/shape/utils/symbol_table.h" + +pir::AttributeMap CreateAttributeMap( + const std::vector &attribute_names, + const std::vector &attributes) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::AttributeMap attr_map; + for (size_t i = 0; i < attribute_names.size(); i++) { + pir::Attribute attr_value = pir::StrAttribute::get(ctx, attributes[i]); + attr_map.insert( + std::pair(attribute_names[i], attr_value)); + } + return attr_map; +} + +pir::Operation *CreateDenseTensorOp( + pir::IrContext *ctx, + const phi::DDim &dims, + const std::vector &attribute_names, + const std::vector &attributes) { + std::vector op_inputs = {}; + pir::Type fp32_dtype = pir::Float32Type::get(ctx); + phi::DataLayout data_layout = phi::DataLayout::NCHW; + phi::LoD lod = {{0, 1, 2}}; + size_t offset = 0; + std::vector op_output_types = { + paddle::dialect::DenseTensorType::get( + ctx, fp32_dtype, dims, data_layout, lod, offset)}; + pir::Operation *op = + pir::Operation::Create(op_inputs, + CreateAttributeMap(attribute_names, attributes), + op_output_types, + pir::OpInfo()); + return op; +} + +TEST(shape_op, dim) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + pir::Builder builder = pir::Builder(ctx, program.block()); + + pir::dialect::DimOp dim_op = builder.Build("S0"); + pir::OpResult res = dim_op.out(); + EXPECT_EQ(dim_op.getName(), "S0"); + dim_op.setName("S1"); + EXPECT_EQ(dim_op.getName(), "S1"); + EXPECT_EQ(res.owner(), dim_op.operation()); + EXPECT_EQ(res.type(), pir::IndexType::get(ctx)); +} + +TEST(shape_op, tie_product_equal) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + pir::Builder builder = pir::Builder(ctx, program.block()); + pir::SymbolTable symbolt_table(program.module_op()); + + pir::OpResult dim_op0 = builder.Build("S0").out(); + pir::OpResult dim_op1 = builder.Build("S1").out(); + pir::OpResult dim_op2 = builder.Build("S2").out(); + pir::OpResult dim_op3 = builder.Build("S3").out(); + pir::OpResult dim_op4 = builder.Build("S4").out(); + + pir::dialect::TieProductEqualOp tie_product_equal = + builder.Build( + 2, + 3, + std::vector{dim_op0, dim_op1, dim_op2, dim_op3, dim_op4}); + + std::vector lhs = tie_product_equal.lhs(); + std::vector rhs = tie_product_equal.rhs(); + + std::vector lhs_ref{dim_op0, dim_op1}; + std::vector rhs_ref{dim_op2, dim_op3, dim_op4}; + + EXPECT_EQ(symbolt_table.insert(tie_product_equal), "tie_product_equal"); + EXPECT_EQ( + symbolt_table.Lookup("tie_product_equal") + .size(), + static_cast(1)); + EXPECT_EQ(symbolt_table.Lookup( + "tie_product_equal")[0], + tie_product_equal); + EXPECT_EQ(lhs, lhs_ref); + EXPECT_EQ(rhs, rhs_ref); +} + +TEST(shape_op, tie_shape) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + pir::Builder 
builder = pir::Builder(ctx, program.block()); + + auto op = CreateDenseTensorOp( + ctx, {pir::ShapedTypeInterface::kDynamic, 2}, {"op_attr"}, {"op_name"}); + pir::OpResult res = op->result(0); + + pir::dialect::TieShapeOp tie_shape_op = + builder.Build(res); + pir::Value tie_shape_op_value = tie_shape_op.value(); + + pir::Attribute attr_s0 = pir::StrAttribute::get(ctx, "S0"); + pir::Attribute attr_s1 = pir::StrAttribute::get(ctx, "S1"); + + std::vector new_attrs = {attr_s0, attr_s1}; + + auto array_attr = pir::ArrayAttribute::get(ctx, new_attrs); + tie_shape_op->set_attribute( + pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), array_attr); + + std::vector arr_attr_vec = + tie_shape_op + ->attribute( + pir::dialect::SymbolicDim::GetSymbolicDimAttrName()) + .AsVector(); + + EXPECT_EQ(tie_shape_op_value, res); + EXPECT_EQ(arr_attr_vec.size(), static_cast(2)); + EXPECT_EQ(arr_attr_vec[0].dyn_cast(), attr_s0); + EXPECT_EQ(arr_attr_vec[1].dyn_cast(), attr_s1); + EXPECT_TRUE(tie_shape_op->HasAttribute( + pir::dialect::SymbolicDim::GetSymbolicDimAttrName())); +} + +TEST(shape_op, func_op) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + ::pir::Builder builder = ::pir::Builder(ctx, program.block()); + pir::dialect::FuncOp func_op = builder.Build(); + auto func_block = func_op.block(); + builder.SetInsertionPointToStart(func_block); + builder.Build(pir::Int32Attribute::get(ctx, 2), + pir::Int32Type::get(ctx)); + EXPECT_EQ(func_block, func_op->region(0).front()); + EXPECT_EQ(func_op->region(0).size(), static_cast(1)); + EXPECT_EQ(func_block->size(), static_cast(1)); +} + +TEST(shape_op, tensor_dim) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + pir::Builder builder = pir::Builder(ctx, program.block()); + + pir::Operation *op = CreateDenseTensorOp( + ctx, {pir::ShapedTypeInterface::kDynamic, 2}, {"op_attr"}, {"op_name"}); + pir::OpResult res_dense_tensor_value = op->result(0); + + pir::dialect::TensorDimOp tensor_dim_op0 = + builder.Build(res_dense_tensor_value, 0); + pir::OpResult res0 = tensor_dim_op0.out(); + + pir::OpResult index_value = + builder + .Build( + pir::Int64Attribute::get(pir::IrContext::Instance(), 1), + pir::IndexType::get(pir::IrContext::Instance())) + ->result(0); + pir::dialect::TensorDimOp tensor_dim_op1 = + builder.Build(res_dense_tensor_value, + index_value); + pir::OpResult res1 = tensor_dim_op1.out(); + + EXPECT_EQ(res0.type(), pir::IndexType::get(ctx)); + EXPECT_EQ(res1.type(), pir::IndexType::get(ctx)); + EXPECT_EQ(tensor_dim_op0.source(), res_dense_tensor_value); + EXPECT_EQ(tensor_dim_op1.source(), res_dense_tensor_value); + EXPECT_EQ(tensor_dim_op1.index(), index_value); +} diff --git a/test/cpp/pir/shape_dialect/shape_struct_test.cc b/test/cpp/pir/shape_dialect/shape_struct_test.cc new file mode 100644 index 0000000000000..64b58a399a150 --- /dev/null +++ b/test/cpp/pir/shape_dialect/shape_struct_test.cc @@ -0,0 +1,503 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/pir/core/block.h" +#include "paddle/pir/core/builder.h" +#include "paddle/pir/core/builtin_type.h" +#include "paddle/pir/core/builtin_type_interfaces.h" +#include "paddle/pir/core/dialect.h" +#include "paddle/pir/core/ir_context.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/dialect/shape/ir/shape_dialect.h" +#include "paddle/pir/dialect/shape/ir/shape_op.h" +#include "paddle/pir/dialect/shape/utils/shape_utils.h" +#include "paddle/pir/dialect/shape/utils/symbol_table.h" + +pir::AttributeMap CreateAttributeMap( + const std::vector &attribute_names, + const std::vector &attributes) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::AttributeMap attr_map; + for (size_t i = 0; i < attribute_names.size(); i++) { + pir::Attribute attr_value = pir::StrAttribute::get(ctx, attributes[i]); + attr_map.insert( + std::pair(attribute_names[i], attr_value)); + } + return attr_map; +} + +pir::Operation *CreateDenseTensorOp( + pir::IrContext *ctx, + const phi::DDim &dims, + const std::vector &attribute_names, + const std::vector &attributes) { + std::vector op_inputs = {}; + pir::Type fp32_dtype = pir::Float32Type::get(ctx); + phi::DataLayout data_layout = phi::DataLayout::NCHW; + phi::LoD lod = {{0, 1, 2}}; + size_t offset = 0; + std::vector op_output_types = { + paddle::dialect::DenseTensorType::get( + ctx, fp32_dtype, dims, data_layout, lod, offset)}; + pir::Operation *op = + pir::Operation::Create(op_inputs, + CreateAttributeMap(attribute_names, attributes), + op_output_types, + pir::OpInfo()); + return op; +} + +TEST(shape_struct_test, symbolic_dim) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + pir::Builder builder = pir::Builder(ctx, program.block()); + + pir::dialect::SymbolicDim sym_dim1 = builder.Build( + "S0", 10, false, false, false, false); + pir::dialect::SymbolicDim sym_dim2 = builder.Build( + "S1", 10, false, false, false, false); + + EXPECT_EQ(sym_dim1.GetDimSize(), 10); + EXPECT_EQ(sym_dim1.GetSymName(), "S0"); + EXPECT_FALSE(sym_dim1.GetKnownNegativeOne()); + EXPECT_FALSE(sym_dim1.GetKnownNonSizeOne()); + EXPECT_FALSE(sym_dim1.GetKnownNonSizeZero()); + EXPECT_FALSE(sym_dim1.GetKnownNonNegative()); + + EXPECT_FALSE(sym_dim1.IsDynamic()); + EXPECT_TRUE(sym_dim1.Merge(sym_dim2)); + + sym_dim1.SetDimSize(20); + sym_dim1.SetSymName("S2"); + sym_dim1.UpdateKnownNegativeOne(true); + sym_dim1.UpdateKnownNonSizeOne(true); + sym_dim1.UpdateKnownNonSizeZero(true); + sym_dim1.UpdateKnownNonNegative(true); + + EXPECT_FALSE(sym_dim1.Merge(sym_dim2)); + + EXPECT_EQ(sym_dim1.GetDimSize(), 20); + EXPECT_EQ(sym_dim1.GetSymName(), "S2"); + EXPECT_TRUE(sym_dim1.GetKnownNegativeOne()); + EXPECT_TRUE(sym_dim1.GetKnownNonSizeOne()); + EXPECT_TRUE(sym_dim1.GetKnownNonSizeZero()); + EXPECT_TRUE(sym_dim1.GetKnownNonNegative()); +} + +TEST(shape_struct_test, symbolic_dim_product) { + pir::IrContext *ctx = pir::IrContext::Instance(); + 
pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + pir::Builder builder = pir::Builder(ctx, program.block()); + pir::dialect::SymbolicDim sym_dim = builder.Build( + "S0", pir::ShapedTypeInterface::kDynamic, false, false, false, false); + pir::SymbolicDimProduct sym_dim_product1; + pir::SymbolicDimProduct sym_dim_product2; + sym_dim_product1.symbols.push_back(sym_dim); + sym_dim_product1.factor *= 10; + EXPECT_EQ(sym_dim_product1.factor, 10); + EXPECT_NE(sym_dim_product1, sym_dim_product2); + EXPECT_FALSE(sym_dim_product1.empty()); +} + +TEST(shape_struct_test, symbolic_dim_table) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + pir::Builder builder = pir::Builder(ctx, program.block()); + pir::dialect::SymbolicDim sym_dim = builder.Build( + "S0", 10, false, false, false, false); + + pir::SymbolTable symbol_table(program.module_op()); + EXPECT_EQ(symbol_table.insert(sym_dim), "S0"); + EXPECT_EQ(symbol_table.Lookup("S0"), sym_dim); + EXPECT_EQ(symbol_table.getOp(), program.module_op()); + EXPECT_FALSE(symbol_table.Lookup("S1")); +} + +TEST(shape_struct_test, symbolic_dim_mgr_simple) { + /******************************************************/ + /* Mgr simple version, only SymbolicDim related func. */ + /******************************************************/ + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + pir::SymbolicDimMgr sym_dim_mgr(program.module_op()); + pir::dialect::SymbolicDim sym_dim_s0 = sym_dim_mgr.NewSymbolicDim(); + pir::dialect::SymbolicDim sym_dim_s1 = sym_dim_mgr.NewSymbolicDim(); + pir::dialect::SymbolicDim sym_dim_c10 = + sym_dim_mgr.NewConstantSymbolicDim(10); + sym_dim_mgr.MapSymbolicDimEqual(sym_dim_s0, sym_dim_s1); + + auto op = CreateDenseTensorOp( + ctx, {pir::ShapedTypeInterface::kDynamic, 2}, {"op_attr"}, {"op_name"}); + pir::Value res = op->result(0); + + std::vector sym_dim_vec = + sym_dim_mgr.CreateSymbolicDimsForRankedValue(res); + + EXPECT_EQ(sym_dim_s0.GetSymName(), "S0"); + EXPECT_EQ(sym_dim_s1.GetSymName(), "S1"); + EXPECT_EQ(sym_dim_s1.GetDimSize(), pir::ShapedTypeInterface::kDynamic); + EXPECT_EQ(sym_dim_c10.GetSymName(), "C10"); + EXPECT_EQ(sym_dim_c10.GetDimSize(), 10); + EXPECT_EQ(sym_dim_vec[0].GetSymName(), "S2"); + EXPECT_EQ(sym_dim_vec[1].GetSymName(), "C2"); + EXPECT_EQ(sym_dim_mgr.symbolTable().Lookup("S0"), + sym_dim_s0); + EXPECT_EQ(sym_dim_mgr.symbolTable().Lookup("C10"), + sym_dim_c10); + EXPECT_EQ(sym_dim_mgr.GetRootSymbolicDim(sym_dim_s1), sym_dim_s0); + EXPECT_TRUE(sym_dim_mgr.IsSymbolicDimEqual(sym_dim_s0, sym_dim_s1)); + EXPECT_FALSE(sym_dim_mgr.IsSymbolicDimEqual(sym_dim_s0, sym_dim_c10)); +} + +TEST(shape_struct_test, symbolic_dim_mgr_complex) { + /***************************************************************/ + /* Mgr with constraintOp, and SymbolicDimProduct related func. 
*/ + /***************************************************************/ + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + pir::SymbolicDimMgr sym_dim_mgr(program.module_op()); + auto func_op = + sym_dim_mgr.symbolTable().getOp()->dyn_cast(); + + pir::Builder builder = pir::Builder(ctx, func_op.block()); + + pir::dialect::SymbolicDim sym_dim_s0 = sym_dim_mgr.NewSymbolicDim("S0"); + pir::dialect::SymbolicDim sym_dim_s1 = sym_dim_mgr.NewSymbolicDim("S1"); + pir::dialect::SymbolicDim sym_dim_s2 = sym_dim_mgr.NewSymbolicDim("S2"); + pir::dialect::SymbolicDim sym_dim_s3 = sym_dim_mgr.NewSymbolicDim("S3"); + pir::dialect::SymbolicDim sym_dim_s4 = sym_dim_mgr.NewSymbolicDim("S4"); + pir::dialect::SymbolicDim sym_dim_s5 = sym_dim_mgr.NewSymbolicDim("S5"); + pir::dialect::SymbolicDim sym_dim_s6 = sym_dim_mgr.NewSymbolicDim("S6"); + pir::dialect::SymbolicDim sym_dim_s7 = sym_dim_mgr.NewSymbolicDim("S7"); + pir::dialect::SymbolicDim sym_dim_s8 = sym_dim_mgr.NewSymbolicDim("S8"); + pir::dialect::SymbolicDim sym_dim_s9 = sym_dim_mgr.NewSymbolicDim("S9"); + pir::dialect::SymbolicDim sym_dim_s10 = sym_dim_mgr.NewSymbolicDim("S10"); + pir::dialect::SymbolicDim sym_dim_s11 = sym_dim_mgr.NewSymbolicDim("S11"); + pir::dialect::SymbolicDim sym_dim_s12 = sym_dim_mgr.NewSymbolicDim("S12"); + pir::dialect::SymbolicDim sym_dim_c10 = + sym_dim_mgr.NewConstantSymbolicDim(10); + pir::dialect::SymbolicDim sym_dim_c20 = + sym_dim_mgr.NewConstantSymbolicDim(20); + + pir::OpResult dim_op_s0 = builder.Build("S0").out(); + pir::OpResult dim_op_s1 = builder.Build("S1").out(); + pir::OpResult dim_op_s2 = builder.Build("S2").out(); + pir::OpResult dim_op_s3 = builder.Build("S3").out(); + pir::OpResult dim_op_s4 = builder.Build("S4").out(); + pir::OpResult dim_op_s5 = builder.Build("S5").out(); + pir::OpResult dim_op_s6 = builder.Build("S6").out(); + pir::OpResult dim_op_s7 = builder.Build("S7").out(); + pir::OpResult dim_op_s8 = builder.Build("S8").out(); + pir::OpResult dim_op_s9 = builder.Build("S9").out(); + pir::OpResult dim_op_s10 = builder.Build("S10").out(); + pir::OpResult dim_op_s11 = builder.Build("S11").out(); + pir::OpResult dim_op_c10 = builder.Build("C10").out(); + pir::OpResult dim_op_c20 = builder.Build("C20").out(); + pir::OpResult constant = + builder + .Build(pir::Int32Attribute::get(ctx, 2), + pir::Int32Type::get(ctx)) + ->result(0); + + // Mark S1 == S2. + builder.Build( + 2, 2, std::vector{constant, dim_op_s1, dim_op_s2, constant}); + // Mark S0 * S1 == S2 * S3, For check S0 == S3. + builder.Build( + 2, + 2, + std::vector{dim_op_s0, dim_op_s1, dim_op_s2, dim_op_s3}); + // Mark S4 * S0 * S1 == S2 * S3 * S5, For check S4 == S5. + builder.Build( + 3, + 3, + std::vector{ + dim_op_s4, dim_op_s0, dim_op_s1, dim_op_s2, dim_op_s3, dim_op_s5}); + // For check S6 == C10 * C20. + builder.Build( + 1, 2, std::vector{dim_op_s6, dim_op_c10, dim_op_c20}); + // Mark C10 * S0 * S1 == S2 * S3 * S7, for check C10 == S7. 
+ builder.Build( + 3, + 3, + std::vector{ + dim_op_c10, dim_op_s0, dim_op_s1, dim_op_s2, dim_op_s3, dim_op_s7}); + + // For unsimplify product case: S8 * S9 == S10 * S11 + builder.Build( + 2, + 2, + std::vector{dim_op_s8, dim_op_s9, dim_op_s10, dim_op_s11}); + + auto op = CreateDenseTensorOp(ctx, + {pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic}, + {"op0_attr"}, + {"op0_name"}); + auto op_ = CreateDenseTensorOp(ctx, + {pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic, + 10, + 20}, + {"op1_attr"}, + {"op1_name"}); + pir::OpResult res = op->result(0); + pir::OpResult res_ = op_->result(0); + + builder.SetInsertionPointToEnd(program.block()); + pir::dialect::TieShapeOp tie_shape_op1 = + builder.Build(res); + pir::dialect::TieShapeOp tie_shape_op2 = + builder.Build(res_); + + pir::Attribute attr_s0 = pir::StrAttribute::get(ctx, "S0"); + pir::Attribute attr_s1 = pir::StrAttribute::get(ctx, "S1"); + pir::Attribute attr_s2 = pir::StrAttribute::get(ctx, "S2"); + pir::Attribute attr_s3 = pir::StrAttribute::get(ctx, "S3"); + pir::Attribute attr_s4 = pir::StrAttribute::get(ctx, "S4"); + pir::Attribute attr_s5 = pir::StrAttribute::get(ctx, "S5"); + pir::Attribute attr_s6 = pir::StrAttribute::get(ctx, "S6"); + pir::Attribute attr_s7 = pir::StrAttribute::get(ctx, "S7"); + pir::Attribute attr_s8 = pir::StrAttribute::get(ctx, "S8"); + pir::Attribute attr_s9 = pir::StrAttribute::get(ctx, "S9"); + pir::Attribute attr_s10 = pir::StrAttribute::get(ctx, "S10"); + pir::Attribute attr_s11 = pir::StrAttribute::get(ctx, "S11"); + pir::Attribute attr_c10 = pir::StrAttribute::get(ctx, "C10"); + pir::Attribute attr_c20 = pir::StrAttribute::get(ctx, "C20"); + + std::vector new_attrs1 = { + attr_s0, attr_s1, attr_s2, attr_s3, attr_s4, attr_s5}; + std::vector new_attrs2 = {attr_s6, + attr_s7, + attr_s8, + attr_s9, + attr_s10, + attr_s11, + attr_c10, + attr_c20}; + std::vector new_attrs_ref = { + attr_s0, attr_s1, attr_s1, attr_s0, attr_s2, attr_s2}; + + auto array_attr1 = pir::ArrayAttribute::get(ctx, new_attrs1); + auto array_attr2 = pir::ArrayAttribute::get(ctx, new_attrs2); + auto array_attr_ref = pir::ArrayAttribute::get(ctx, new_attrs_ref); + + tie_shape_op1->set_attribute( + pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), array_attr1); + tie_shape_op2->set_attribute( + pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), array_attr2); + + EXPECT_TRUE(sym_dim_mgr.Load()); + + // For check indirect equality: S1 * S4 == S2 * S5 + pir::SymbolicDimProduct sym_dim_product_lhs1; + pir::SymbolicDimProduct sym_dim_product_rhs1; + + sym_dim_product_lhs1.symbols.push_back(sym_dim_s1); + sym_dim_product_lhs1.symbols.push_back(sym_dim_s4); + + sym_dim_product_rhs1.symbols.push_back(sym_dim_s2); + sym_dim_product_rhs1.symbols.push_back(sym_dim_s5); + + // For uncompletely simplied product check: S8 * S9 * S12 == S10 * S11 * S12 + pir::SymbolicDimProduct sym_dim_product_lhs2; + pir::SymbolicDimProduct sym_dim_product_rhs2; + + sym_dim_product_lhs2.symbols.push_back(sym_dim_s8); + sym_dim_product_lhs2.symbols.push_back(sym_dim_s9); + sym_dim_product_lhs2.symbols.push_back(sym_dim_s12); + + sym_dim_product_rhs2.symbols.push_back(sym_dim_s10); + sym_dim_product_rhs2.symbols.push_back(sym_dim_s11); + 
sym_dim_product_rhs2.symbols.push_back(sym_dim_s12); + + // For check SimplifySymbolicDimProduct, {factor = 1, Sym = {S7}} => {factor = + // 10} + pir::SymbolicDimProduct sym_dim_product_s7; + sym_dim_product_s7.symbols.push_back(sym_dim_s7); + pir::SymbolicDimProduct simplified_product_s7 = + sym_dim_mgr.SimplifySymbolicDimProduct(sym_dim_product_s7); + + // For check SimplifySymbolicDimProductPair, X * Y * Y, Y * Y * Z => X, Z + pir::SymbolicDimProduct sym_dim_product_pair_lhs; + pir::SymbolicDimProduct sym_dim_product_pair_rhs; + pir::SymbolicDimProduct new_lhs, new_rhs; + sym_dim_product_pair_lhs.symbols.push_back(sym_dim_s4); + sym_dim_product_pair_lhs.symbols.push_back(sym_dim_s1); + sym_dim_product_pair_lhs.symbols.push_back(sym_dim_s2); + sym_dim_product_pair_rhs.symbols.push_back(sym_dim_s1); + sym_dim_product_pair_rhs.symbols.push_back(sym_dim_s2); + sym_dim_product_pair_rhs.symbols.push_back(sym_dim_s3); + + std::tie(new_lhs, new_rhs) = sym_dim_mgr.SimplifySymbolicDimProductPair( + sym_dim_product_pair_lhs, sym_dim_product_pair_rhs); + + // For check SymbolicDimProductDivide, {S4 * S1 * C20} / {S1 * C10} => {factor + // = 2 Sym = {S4}} + pir::SymbolicDimProduct sym_dim_product_div_lhs; + pir::SymbolicDimProduct sym_dim_product_div_rhs; + sym_dim_product_div_lhs.symbols.push_back(sym_dim_s4); + sym_dim_product_div_lhs.symbols.push_back(sym_dim_s1); + sym_dim_product_div_lhs.symbols.push_back(sym_dim_c20); + sym_dim_product_div_rhs.symbols.push_back(sym_dim_s1); + sym_dim_product_div_rhs.symbols.push_back(sym_dim_c10); + + pir::SymbolicDimProduct *divRes = sym_dim_mgr.SymbolicDimProductDivide( + sym_dim_product_div_lhs, sym_dim_product_div_rhs); + + EXPECT_TRUE(sym_dim_mgr.IsSymbolicDimEqual(sym_dim_s1, sym_dim_s2)); + EXPECT_TRUE(sym_dim_mgr.IsSymbolicDimEqual(sym_dim_s0, sym_dim_s3)); + EXPECT_TRUE(sym_dim_mgr.IsSymbolicDimEqual(sym_dim_s4, sym_dim_s5)); + EXPECT_EQ(sym_dim_s6.GetDimSize(), 200); + EXPECT_EQ(sym_dim_mgr.symbolTable().Lookup("C20"), + sym_dim_c20); + EXPECT_EQ(sym_dim_s7.GetDimSize(), sym_dim_c10.GetDimSize()); + EXPECT_EQ(simplified_product_s7.factor, 10); + EXPECT_EQ(simplified_product_s7.symbols.size(), static_cast(0)); + EXPECT_EQ(new_lhs.symbols.size(), static_cast(1)); + EXPECT_EQ(new_rhs.symbols.size(), static_cast(1)); + EXPECT_EQ(new_lhs.symbols[0], sym_dim_mgr.GetRootSymbolicDim(sym_dim_s4)); + EXPECT_EQ(new_rhs.symbols[0], sym_dim_mgr.GetRootSymbolicDim(sym_dim_s3)); + EXPECT_EQ(divRes->factor, 2); + EXPECT_EQ(divRes->symbols.size(), static_cast(1)); + EXPECT_EQ(divRes->symbols[0], sym_dim_mgr.GetRootSymbolicDim(sym_dim_s4)); + EXPECT_TRUE(sym_dim_mgr.IsSymbolicDimProductEqual(sym_dim_product_lhs1, + sym_dim_product_rhs1)); + EXPECT_TRUE(sym_dim_mgr.IsSymbolicDimProductEqual(sym_dim_product_lhs2, + sym_dim_product_rhs2)); + EXPECT_TRUE(sym_dim_mgr.Save()); + + pir::SymbolicDimMgr sym_dim_mgr_new(program.module_op()); + EXPECT_TRUE(sym_dim_mgr_new.Load()); + + auto attrs = tie_shape_op1.attribute( + pir::dialect::SymbolicDim::GetSymbolicDimAttrName()); + EXPECT_FALSE( + sym_dim_mgr_new.symbolTable().Lookup("S7")); + EXPECT_EQ(sym_dim_mgr_new.symbolTable() + .Lookup("tie_product_equal") + .size(), + static_cast(1)); + + EXPECT_EQ(attrs.AsVector(), array_attr_ref.AsVector()); +} + +TEST(shape_struct_test, shape_analysis) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ::pir::Builder builder = ::pir::Builder(ctx, program.block()); + pir::dialect::FuncOp 
func_op = builder.Build(); + + phi::DDim dims_D_2 = {pir::ShapedTypeInterface::kDynamic, 2}; + phi::DDim dims_2_2 = {2, 2}; + phi::DDim dims_D = {pir::ShapedTypeInterface::kDynamic}; + + // same shape with dynamic: value1 == value2 + auto op1 = CreateDenseTensorOp(ctx, dims_D_2, {"op1_attr"}, {"op1_name"}); + auto op2 = CreateDenseTensorOp(ctx, dims_D_2, {"op2_attr"}, {"op2_name"}); + pir::OpResult value1 = op1->result(0); + pir::OpResult value2 = op2->result(0); + + // same shape with static: value3 == value4 + auto op3 = CreateDenseTensorOp(ctx, dims_2_2, {"op3_attr"}, {"op3_name"}); + auto op4 = CreateDenseTensorOp(ctx, dims_2_2, {"op4_attr"}, {"op4_name"}); + pir::OpResult value3 = op3->result(0); + pir::OpResult value4 = op4->result(0); + + // one dimension with dynamic: value5 != value1 != value3 + auto op5 = CreateDenseTensorOp(ctx, dims_D, {"op5_attr"}, {"op5_name"}); + pir::OpResult value5 = op5->result(0); + + pir::dialect::TieShapeOp tie_shape_op1 = + builder.Build(value1); + pir::dialect::TieShapeOp tie_shape_op2 = + builder.Build(value2); + pir::dialect::TieShapeOp tie_shape_op3 = + builder.Build(value3); + pir::dialect::TieShapeOp tie_shape_op4 = + builder.Build(value4); + pir::dialect::TieShapeOp tie_shape_op5 = + builder.Build(value5); + + builder.SetInsertionPointToEnd(func_op.block()); + builder.Build("C2", 2, true, false, true, true); + pir::dialect::SymbolicDim sym_dim_s0 = + builder.Build( + "S0", pir::ShapedTypeInterface::kDynamic, false, false, true, true); + pir::dialect::SymbolicDim sym_dim_s1 = + builder.Build( + "S1", pir::ShapedTypeInterface::kDynamic, false, false, true, true); + pir::dialect::SymbolicDim sym_dim_s2 = + builder.Build( + "S2", pir::ShapedTypeInterface::kDynamic, false, false, true, true); + + pir::Attribute attr_s0 = pir::StrAttribute::get(ctx, "S0"); + pir::Attribute attr_s1 = pir::StrAttribute::get(ctx, "S1"); + pir::Attribute attr_s2 = pir::StrAttribute::get(ctx, "S2"); + pir::Attribute attr_c2 = pir::StrAttribute::get(ctx, "C2"); + + auto attr_op1 = pir::ArrayAttribute::get(ctx, {attr_s0, attr_c2}); + auto attr_op2 = pir::ArrayAttribute::get(ctx, {attr_s1, attr_c2}); + auto attr_op3 = pir::ArrayAttribute::get(ctx, {attr_c2, attr_c2}); + auto attr_op4 = pir::ArrayAttribute::get(ctx, {attr_c2, attr_c2}); + auto attr_op5 = pir::ArrayAttribute::get(ctx, {attr_s2}); + + tie_shape_op1->set_attribute( + pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), attr_op1); + tie_shape_op2->set_attribute( + pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), attr_op2); + tie_shape_op3->set_attribute( + pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), attr_op3); + tie_shape_op4->set_attribute( + pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), attr_op4); + tie_shape_op5->set_attribute( + pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), attr_op5); + + pir::ShapeConstraintIRAnalysis shape_analysis(program.module_op()); + EXPECT_TRUE(shape_analysis.IsShapeEqual(value3, value4)); + EXPECT_FALSE(shape_analysis.IsShapeEqual(value1, value2)); + EXPECT_FALSE(shape_analysis.IsShapeEqual(value1, value3)); + EXPECT_FALSE(shape_analysis.IsShapeEqual(value1, value5)); + EXPECT_FALSE(shape_analysis.IsShapeEqual(value3, value5)); + EXPECT_TRUE(shape_analysis.IsProductEqual(value1, {1}, value3, {0})); + EXPECT_TRUE(shape_analysis.IsSameNumElements(value4, value3)); + + shape_analysis.symbolicDimMgr().MapSymbolicDimEqual(sym_dim_s0, sym_dim_s1); + shape_analysis.symbolicDimMgr().MapSymbolicDimEqual(sym_dim_s0, sym_dim_s2); + + 
EXPECT_TRUE(shape_analysis.IsShapeEqual(value1, value2)); + EXPECT_FALSE(shape_analysis.IsShapeEqual(value1, value5)); +} diff --git a/test/cpp/pir/shape_dialect/symbolic_op_test.cc b/test/cpp/pir/shape_dialect/symbolic_op_test.cc deleted file mode 100644 index 0f8ae6e204047..0000000000000 --- a/test/cpp/pir/shape_dialect/symbolic_op_test.cc +++ /dev/null @@ -1,622 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" -#include "paddle/pir/core/block.h" -#include "paddle/pir/core/builder.h" -#include "paddle/pir/core/builtin_type.h" -#include "paddle/pir/core/builtin_type_interfaces.h" -#include "paddle/pir/core/dialect.h" -#include "paddle/pir/core/ir_context.h" -#include "paddle/pir/core/program.h" -#include "paddle/pir/dialect/shape/ir/shape_dialect.h" -#include "paddle/pir/dialect/shape/ir/shape_op.h" -#include "paddle/pir/dialect/shape/utils/shape_utils.h" -#include "paddle/pir/dialect/shape/utils/symbol_table.h" - -pir::AttributeMap CreateAttributeMap( - const std::vector &attribute_names, - const std::vector &attributes) { - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::AttributeMap attr_map; - for (size_t i = 0; i < attribute_names.size(); i++) { - pir::Attribute attr_value = pir::StrAttribute::get(ctx, attributes[i]); - attr_map.insert( - std::pair(attribute_names[i], attr_value)); - } - return attr_map; -} - -pir::Operation *CreateDenseTensorOp( - pir::IrContext *ctx, - const phi::DDim &dims, - const std::vector &attribute_names, - const std::vector &attributes) { - std::vector op_inputs = {}; - pir::Type fp32_dtype = pir::Float32Type::get(ctx); - phi::DataLayout data_layout = phi::DataLayout::NCHW; - phi::LoD lod = {{0, 1, 2}}; - size_t offset = 0; - std::vector op_output_types = { - paddle::dialect::DenseTensorType::get( - ctx, fp32_dtype, dims, data_layout, lod, offset)}; - pir::Operation *op = - pir::Operation::Create(op_inputs, - CreateAttributeMap(attribute_names, attributes), - op_output_types, - pir::OpInfo()); - return op; -} - -TEST(assist_struct_test, symbolic_dim) { - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::Program program(ctx); - ctx->GetOrRegisterDialect(); - pir::Builder builder = pir::Builder(ctx, program.block()); - - pir::dialect::SymbolicDim sym_dim1 = builder.Build( - "S0", 10, false, false, false, false); - pir::dialect::SymbolicDim sym_dim2 = builder.Build( - "S1", 10, false, false, false, false); - - EXPECT_EQ(sym_dim1.GetDimSize(), 10); - EXPECT_EQ(sym_dim1.GetSymName(), "S0"); - EXPECT_FALSE(sym_dim1.GetKnownNegativeOne()); - EXPECT_FALSE(sym_dim1.GetKnownNonSizeOne()); - EXPECT_FALSE(sym_dim1.GetKnownNonSizeZero()); - EXPECT_FALSE(sym_dim1.GetKnownNonNegative()); - - EXPECT_FALSE(sym_dim1.IsDynamic()); - EXPECT_TRUE(sym_dim1.Merge(sym_dim2)); - - sym_dim1.SetDimSize(20); - sym_dim1.SetSymName("S2"); - 
sym_dim1.UpdateKnownNegativeOne(true); - sym_dim1.UpdateKnownNonSizeOne(true); - sym_dim1.UpdateKnownNonSizeZero(true); - sym_dim1.UpdateKnownNonNegative(true); - - EXPECT_FALSE(sym_dim1.Merge(sym_dim2)); - - EXPECT_EQ(sym_dim1.GetDimSize(), 20); - EXPECT_EQ(sym_dim1.GetSymName(), "S2"); - EXPECT_TRUE(sym_dim1.GetKnownNegativeOne()); - EXPECT_TRUE(sym_dim1.GetKnownNonSizeOne()); - EXPECT_TRUE(sym_dim1.GetKnownNonSizeZero()); - EXPECT_TRUE(sym_dim1.GetKnownNonNegative()); -} - -TEST(assist_struct_test, symbolic_dim_product) { - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::Program program(ctx); - ctx->GetOrRegisterDialect(); - pir::Builder builder = pir::Builder(ctx, program.block()); - pir::dialect::SymbolicDim symDim = builder.Build( - "S0", pir::ShapedTypeInterface::kDynamic, false, false, false, false); - pir::SymbolicDimProduct symDimProduct; - pir::SymbolicDimProduct symDimProduct_; - symDimProduct.symbols.push_back(symDim); - symDimProduct.factor *= 10; - EXPECT_EQ(symDimProduct.factor, 10); - EXPECT_NE(symDimProduct, symDimProduct_); - EXPECT_FALSE(symDimProduct.empty()); -} - -TEST(assist_struct_test, symbolic_dim_table) { - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::Program program(ctx); - ctx->GetOrRegisterDialect(); - pir::Builder builder = pir::Builder(ctx, program.block()); - pir::dialect::SymbolicDim symDim = builder.Build( - "S0", 10, false, false, false, false); - - pir::SymbolTable symbolTable(program.module_op()); - EXPECT_EQ(symbolTable.insert(symDim), "S0"); - EXPECT_EQ(symbolTable.Lookup("S0"), symDim); - EXPECT_EQ(symbolTable.getOp(), program.module_op()); - EXPECT_FALSE(symbolTable.Lookup("S1")); -} - -TEST(assist_struct_test, symbolic_dim_mgr_simple) { - /******************************************************/ - /* Mgr simple version, only SymbolicDim related func. */ - /******************************************************/ - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::Program program(ctx); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - - pir::SymbolicDimMgr symDimMgr(program.module_op()); - pir::dialect::SymbolicDim symDimS0 = symDimMgr.NewSymbolicDim(); - pir::dialect::SymbolicDim symDimS1 = symDimMgr.NewSymbolicDim(); - pir::dialect::SymbolicDim symDimC10 = symDimMgr.NewConstantSymbolicDim(10); - symDimMgr.MapSymbolicDimEqual(symDimS0, symDimS1); - - auto op = CreateDenseTensorOp( - ctx, {pir::ShapedTypeInterface::kDynamic, 2}, {"op_attr"}, {"op_name"}); - pir::Value res = op->result(0); - - std::vector symDimVec = - symDimMgr.CreateSymbolicDimsForRankedValue(res); - - EXPECT_EQ(symDimS0.GetSymName(), "S0"); - EXPECT_EQ(symDimS1.GetSymName(), "S1"); - EXPECT_EQ(symDimS1.GetDimSize(), pir::ShapedTypeInterface::kDynamic); - EXPECT_EQ(symDimC10.GetSymName(), "C10"); - EXPECT_EQ(symDimC10.GetDimSize(), 10); - EXPECT_EQ(symDimVec[0].GetSymName(), "S2"); - EXPECT_EQ(symDimVec[1].GetSymName(), "C2"); - EXPECT_EQ(symDimMgr.symbolTable().Lookup("S0"), - symDimS0); - EXPECT_EQ(symDimMgr.symbolTable().Lookup("C10"), - symDimC10); - EXPECT_EQ(symDimMgr.GetRootSymbolicDim(symDimS1), symDimS0); - EXPECT_TRUE(symDimMgr.IsSymbolicDimEqual(symDimS0, symDimS1)); - EXPECT_FALSE(symDimMgr.IsSymbolicDimEqual(symDimS0, symDimC10)); -} - -TEST(assist_struct_test, symbolic_dim_mgr_complex) { - /***************************************************************/ - /* Mgr with constraintOp, and SymbolicDimProduct related func. 
*/ - /***************************************************************/ - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::Program program(ctx); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - - pir::SymbolicDimMgr symDimMgr(program.module_op()); - auto funcOp = - symDimMgr.symbolTable().getOp()->dyn_cast(); - - pir::Builder builder = pir::Builder(ctx, funcOp.block()); - - pir::dialect::SymbolicDim symDimS0 = symDimMgr.NewSymbolicDim("S0"); - pir::dialect::SymbolicDim symDimS1 = symDimMgr.NewSymbolicDim("S1"); - pir::dialect::SymbolicDim symDimS2 = symDimMgr.NewSymbolicDim("S2"); - pir::dialect::SymbolicDim symDimS3 = symDimMgr.NewSymbolicDim("S3"); - pir::dialect::SymbolicDim symDimS4 = symDimMgr.NewSymbolicDim("S4"); - pir::dialect::SymbolicDim symDimS5 = symDimMgr.NewSymbolicDim("S5"); - pir::dialect::SymbolicDim symDimS6 = symDimMgr.NewSymbolicDim("S6"); - pir::dialect::SymbolicDim symDimS7 = symDimMgr.NewSymbolicDim("S7"); - pir::dialect::SymbolicDim symDimS8 = symDimMgr.NewSymbolicDim("S8"); - pir::dialect::SymbolicDim symDimS9 = symDimMgr.NewSymbolicDim("S9"); - pir::dialect::SymbolicDim symDimS10 = symDimMgr.NewSymbolicDim("S10"); - pir::dialect::SymbolicDim symDimS11 = symDimMgr.NewSymbolicDim("S11"); - pir::dialect::SymbolicDim symDimS12 = symDimMgr.NewSymbolicDim("S12"); - pir::dialect::SymbolicDim symDimC10 = symDimMgr.NewConstantSymbolicDim(10); - pir::dialect::SymbolicDim symDimC20 = symDimMgr.NewConstantSymbolicDim(20); - - pir::OpResult dimOpS0 = builder.Build("S0").out(); - pir::OpResult dimOpS1 = builder.Build("S1").out(); - pir::OpResult dimOpS2 = builder.Build("S2").out(); - pir::OpResult dimOpS3 = builder.Build("S3").out(); - pir::OpResult dimOpS4 = builder.Build("S4").out(); - pir::OpResult dimOpS5 = builder.Build("S5").out(); - pir::OpResult dimOpS6 = builder.Build("S6").out(); - pir::OpResult dimOpS7 = builder.Build("S7").out(); - pir::OpResult dimOpS8 = builder.Build("S8").out(); - pir::OpResult dimOpS9 = builder.Build("S9").out(); - pir::OpResult dimOpS10 = builder.Build("S10").out(); - pir::OpResult dimOpS11 = builder.Build("S11").out(); - pir::OpResult dimOpC10 = builder.Build("C10").out(); - pir::OpResult dimOpC20 = builder.Build("C20").out(); - pir::OpResult constant = - builder - .Build(pir::Int32Attribute::get(ctx, 2), - pir::Int32Type::get(ctx)) - ->result(0); - - // Mark S1 == S2. - builder.Build( - 2, 2, std::vector{constant, dimOpS1, dimOpS2, constant}); - // Mark S0 * S1 == S2 * S3, For check S0 == S3. - builder.Build( - 2, 2, std::vector{dimOpS0, dimOpS1, dimOpS2, dimOpS3}); - // Mark S4 * S0 * S1 == S2 * S3 * S5, For check S4 == S5. - builder.Build( - 3, - 3, - std::vector{ - dimOpS4, dimOpS0, dimOpS1, dimOpS2, dimOpS3, dimOpS5}); - // For check S6 == C10 * C20. - builder.Build( - 1, 2, std::vector{dimOpS6, dimOpC10, dimOpC20}); - // Mark C10 * S0 * S1 == S2 * S3 * S7, for check C10 == S7. 
- builder.Build( - 3, - 3, - std::vector{ - dimOpC10, dimOpS0, dimOpS1, dimOpS2, dimOpS3, dimOpS7}); - - // For unsimplify product case: S8 * S9 == S10 * S11 - builder.Build( - 2, 2, std::vector{dimOpS8, dimOpS9, dimOpS10, dimOpS11}); - - auto op = CreateDenseTensorOp(ctx, - {pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic}, - {"op0_attr"}, - {"op0_name"}); - auto op_ = CreateDenseTensorOp(ctx, - {pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic, - 10, - 20}, - {"op1_attr"}, - {"op1_name"}); - pir::OpResult res = op->result(0); - pir::OpResult res_ = op_->result(0); - - builder.SetInsertionPointToEnd(program.block()); - pir::dialect::TieShapeOp tieShapeOp = - builder.Build(res); - pir::dialect::TieShapeOp tieShapeOp_ = - builder.Build(res_); - - pir::Attribute attrS0 = pir::StrAttribute::get(ctx, "S0"); - pir::Attribute attrS1 = pir::StrAttribute::get(ctx, "S1"); - pir::Attribute attrS2 = pir::StrAttribute::get(ctx, "S2"); - pir::Attribute attrS3 = pir::StrAttribute::get(ctx, "S3"); - pir::Attribute attrS4 = pir::StrAttribute::get(ctx, "S4"); - pir::Attribute attrS5 = pir::StrAttribute::get(ctx, "S5"); - pir::Attribute attrS6 = pir::StrAttribute::get(ctx, "S6"); - pir::Attribute attrS7 = pir::StrAttribute::get(ctx, "S7"); - pir::Attribute attrS8 = pir::StrAttribute::get(ctx, "S8"); - pir::Attribute attrS9 = pir::StrAttribute::get(ctx, "S9"); - pir::Attribute attrS10 = pir::StrAttribute::get(ctx, "S10"); - pir::Attribute attrS11 = pir::StrAttribute::get(ctx, "S11"); - pir::Attribute attrC10 = pir::StrAttribute::get(ctx, "C10"); - pir::Attribute attrC20 = pir::StrAttribute::get(ctx, "C20"); - - std::vector newAttrs = { - attrS0, attrS1, attrS2, attrS3, attrS4, attrS5}; - std::vector newAttrsRef = { - attrS0, attrS1, attrS1, attrS0, attrS2, attrS2}; - std::vector newAttrs_ = { - attrS6, attrS7, attrS8, attrS9, attrS10, attrS11, attrC10, attrC20}; - - auto arrayAttr = pir::ArrayAttribute::get(ctx, newAttrs); - auto arrayAttrRef = pir::ArrayAttribute::get(ctx, newAttrsRef); - auto arrayAttr_ = pir::ArrayAttribute::get(ctx, newAttrs_); - tieShapeOp->set_attribute(pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), - arrayAttr); - tieShapeOp_->set_attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), arrayAttr_); - - EXPECT_TRUE(symDimMgr.Load()); - - // For check indirect equality: S1 * S4 == S2 * S5 - pir::SymbolicDimProduct symDimProductLhs; - pir::SymbolicDimProduct symDimProductRhs; - - symDimProductLhs.symbols.push_back(symDimS1); - symDimProductLhs.symbols.push_back(symDimS4); - - symDimProductRhs.symbols.push_back(symDimS2); - symDimProductRhs.symbols.push_back(symDimS5); - - // For uncompletely simplied product check: S8 * S9 * S12 == S10 * S11 * S12 - pir::SymbolicDimProduct symDimProductLhs_; - pir::SymbolicDimProduct symDimProductRhs_; - - symDimProductLhs_.symbols.push_back(symDimS8); - symDimProductLhs_.symbols.push_back(symDimS9); - symDimProductLhs_.symbols.push_back(symDimS12); - - symDimProductRhs_.symbols.push_back(symDimS10); - symDimProductRhs_.symbols.push_back(symDimS11); - symDimProductRhs_.symbols.push_back(symDimS12); - - // For check SimplifySymbolicDimProduct, {factor = 1, Sym = {S7}} => {factor = - // 10} - pir::SymbolicDimProduct 
symDimProductS7; - symDimProductS7.symbols.push_back(symDimS7); - pir::SymbolicDimProduct simplifiedProductS7 = - symDimMgr.SimplifySymbolicDimProduct(symDimProductS7); - - // For check SimplifySymbolicDimProductPair, X * Y * Y, Y * Y * Z => X, Z - pir::SymbolicDimProduct symDimProductPairLhs; - pir::SymbolicDimProduct symDimProductPairRhs; - pir::SymbolicDimProduct newLhs, newRhs; - symDimProductPairLhs.symbols.push_back(symDimS4); - symDimProductPairLhs.symbols.push_back(symDimS1); - symDimProductPairLhs.symbols.push_back(symDimS2); - symDimProductPairRhs.symbols.push_back(symDimS1); - symDimProductPairRhs.symbols.push_back(symDimS2); - symDimProductPairRhs.symbols.push_back(symDimS3); - - std::tie(newLhs, newRhs) = symDimMgr.SimplifySymbolicDimProductPair( - symDimProductPairLhs, symDimProductPairRhs); - - // For check SymbolicDimProductDivide, {S4 * S1 * C20} / {S1 * C10} => {factor - // = 2 Sym = {S4}} - pir::SymbolicDimProduct symDimProductDivLhs; - pir::SymbolicDimProduct symDimProductDivRhs; - symDimProductDivLhs.symbols.push_back(symDimS4); - symDimProductDivLhs.symbols.push_back(symDimS1); - symDimProductDivLhs.symbols.push_back(symDimC20); - symDimProductDivRhs.symbols.push_back(symDimS1); - symDimProductDivRhs.symbols.push_back(symDimC10); - - pir::SymbolicDimProduct *divRes = symDimMgr.SymbolicDimProductDivide( - symDimProductDivLhs, symDimProductDivRhs); - - EXPECT_TRUE(symDimMgr.IsSymbolicDimEqual(symDimS1, symDimS2)); - EXPECT_TRUE(symDimMgr.IsSymbolicDimEqual(symDimS0, symDimS3)); - EXPECT_TRUE(symDimMgr.IsSymbolicDimEqual(symDimS4, symDimS5)); - EXPECT_EQ(symDimS6.GetDimSize(), 200); - EXPECT_EQ(symDimMgr.symbolTable().Lookup("C20"), - symDimC20); - EXPECT_EQ(symDimS7.GetDimSize(), symDimC10.GetDimSize()); - EXPECT_EQ(simplifiedProductS7.factor, 10); - EXPECT_EQ(simplifiedProductS7.symbols.size(), static_cast(0)); - EXPECT_EQ(newLhs.symbols.size(), static_cast(1)); - EXPECT_EQ(newRhs.symbols.size(), static_cast(1)); - EXPECT_EQ(newLhs.symbols[0], symDimMgr.GetRootSymbolicDim(symDimS4)); - EXPECT_EQ(newRhs.symbols[0], symDimMgr.GetRootSymbolicDim(symDimS3)); - EXPECT_EQ(divRes->factor, 2); - EXPECT_EQ(divRes->symbols.size(), static_cast(1)); - EXPECT_EQ(divRes->symbols[0], symDimMgr.GetRootSymbolicDim(symDimS4)); - EXPECT_TRUE( - symDimMgr.IsSymbolicDimProductEqual(symDimProductLhs, symDimProductRhs)); - EXPECT_TRUE(symDimMgr.IsSymbolicDimProductEqual(symDimProductLhs_, - symDimProductRhs_)); - EXPECT_TRUE(symDimMgr.Save()); - - pir::SymbolicDimMgr symDimMgr_(program.module_op()); - EXPECT_TRUE(symDimMgr_.Load()); - auto attrs = tieShapeOp.attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName()); - EXPECT_FALSE( - symDimMgr_.symbolTable().Lookup("S7")); - EXPECT_EQ(symDimMgr_.symbolTable() - .Lookup("tie_product_equal") - .size(), - static_cast(1)); - - EXPECT_EQ(attrs.AsVector(), arrayAttrRef.AsVector()); -} - -TEST(shape_op, dim) { - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::Program program(ctx); - ctx->GetOrRegisterDialect(); - pir::Builder builder = pir::Builder(ctx, program.block()); - - pir::dialect::DimOp dimOp = builder.Build("S0"); - pir::OpResult res = dimOp.out(); - EXPECT_EQ(dimOp.getName(), "S0"); - dimOp.setName("S1"); - EXPECT_EQ(dimOp.getName(), "S1"); - EXPECT_EQ(res.owner(), dimOp.operation()); - EXPECT_EQ(res.type(), pir::IndexType::get(ctx)); -} - -TEST(shape_op, tie_product_equal) { - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::Program program(ctx); - ctx->GetOrRegisterDialect(); - pir::Builder builder = 
pir::Builder(ctx, program.block()); - pir::SymbolTable symbolTable(program.module_op()); - - pir::OpResult dimOp0 = builder.Build("S0").out(); - pir::OpResult dimOp1 = builder.Build("S1").out(); - pir::OpResult dimOp2 = builder.Build("S2").out(); - pir::OpResult dimOp3 = builder.Build("S3").out(); - pir::OpResult dimOp4 = builder.Build("S4").out(); - - pir::dialect::TieProductEqualOp tie_product_equal = - builder.Build( - 2, - 3, - std::vector{dimOp0, dimOp1, dimOp2, dimOp3, dimOp4}); - - std::vector lhs = tie_product_equal.lhs(); - std::vector rhs = tie_product_equal.rhs(); - - std::vector lhs_ref{dimOp0, dimOp1}; - std::vector rhs_ref{dimOp2, dimOp3, dimOp4}; - - EXPECT_EQ(symbolTable.insert(tie_product_equal), "tie_product_equal"); - EXPECT_EQ( - symbolTable.Lookup("tie_product_equal") - .size(), - static_cast(1)); - EXPECT_EQ(symbolTable.Lookup( - "tie_product_equal")[0], - tie_product_equal); - EXPECT_EQ(lhs, lhs_ref); - EXPECT_EQ(rhs, rhs_ref); -} - -TEST(shape_op, tie_shape) { - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::Program program(ctx); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - - pir::Builder builder = pir::Builder(ctx, program.block()); - - auto op = CreateDenseTensorOp( - ctx, {pir::ShapedTypeInterface::kDynamic, 2}, {"op_attr"}, {"op_name"}); - pir::OpResult res = op->result(0); - - pir::dialect::TieShapeOp tieShapeOp = - builder.Build(res); - pir::Value tieShapeOpValue = tieShapeOp.value(); - - pir::Attribute attrS0 = pir::StrAttribute::get(ctx, "S0"); - pir::Attribute attrS1 = pir::StrAttribute::get(ctx, "S1"); - - std::vector newAttrs = {attrS0, attrS1}; - - auto arrayAttr = pir::ArrayAttribute::get(ctx, newAttrs); - tieShapeOp->set_attribute(pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), - arrayAttr); - - std::vector arrAttrVec = - tieShapeOp - ->attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName()) - .AsVector(); - - EXPECT_EQ(tieShapeOpValue, res); - EXPECT_EQ(arrAttrVec.size(), static_cast(2)); - EXPECT_EQ(arrAttrVec[0].dyn_cast(), attrS0); - EXPECT_EQ(arrAttrVec[1].dyn_cast(), attrS1); - EXPECT_TRUE(tieShapeOp->HasAttribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName())); -} - -TEST(shape_op, func_op) { - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::Program program(ctx); - ctx->GetOrRegisterDialect(); - ::pir::Builder builder = ::pir::Builder(ctx, program.block()); - pir::dialect::FuncOp funcOp = builder.Build(); - auto funcBlock = funcOp.block(); - builder.SetInsertionPointToStart(funcBlock); - builder.Build(pir::Int32Attribute::get(ctx, 2), - pir::Int32Type::get(ctx)); - EXPECT_EQ(funcBlock, funcOp->region(0).front()); - EXPECT_EQ(funcOp->region(0).size(), static_cast(1)); - EXPECT_EQ(funcBlock->size(), static_cast(1)); -} - -TEST(assist_struct_test, shape_analysis) { - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::Program program(ctx); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - ::pir::Builder builder = ::pir::Builder(ctx, program.block()); - pir::dialect::FuncOp funcOp = builder.Build(); - - phi::DDim dims_D_2 = {pir::ShapedTypeInterface::kDynamic, 2}; - phi::DDim dims_2_2 = {2, 2}; - phi::DDim dims_D = {pir::ShapedTypeInterface::kDynamic}; - - // same shape with dynamic: value1 == value2 - auto op1 = CreateDenseTensorOp(ctx, dims_D_2, {"op1_attr"}, {"op1_name"}); - auto op2 = CreateDenseTensorOp(ctx, dims_D_2, {"op2_attr"}, {"op2_name"}); - pir::OpResult value1 = op1->result(0); - pir::OpResult value2 = op2->result(0); - - // same shape with static: value3 
== value4 - auto op3 = CreateDenseTensorOp(ctx, dims_2_2, {"op3_attr"}, {"op3_name"}); - auto op4 = CreateDenseTensorOp(ctx, dims_2_2, {"op4_attr"}, {"op4_name"}); - pir::OpResult value3 = op3->result(0); - pir::OpResult value4 = op4->result(0); - - // one dimension with dynamic: value5 != value1 != value3 - auto op5 = CreateDenseTensorOp(ctx, dims_D, {"op5_attr"}, {"op5_name"}); - pir::OpResult value5 = op5->result(0); - - pir::dialect::TieShapeOp tieShapeOp1 = - builder.Build(value1); - pir::dialect::TieShapeOp tieShapeOp2 = - builder.Build(value2); - pir::dialect::TieShapeOp tieShapeOp3 = - builder.Build(value3); - pir::dialect::TieShapeOp tieShapeOp4 = - builder.Build(value4); - pir::dialect::TieShapeOp tieShapeOp5 = - builder.Build(value5); - - builder.SetInsertionPointToEnd(funcOp.block()); - builder.Build("C2", 2, true, false, true, true); - pir::dialect::SymbolicDim symDimS0 = builder.Build( - "S0", pir::ShapedTypeInterface::kDynamic, false, false, true, true); - pir::dialect::SymbolicDim symDimS1 = builder.Build( - "S1", pir::ShapedTypeInterface::kDynamic, false, false, true, true); - pir::dialect::SymbolicDim symDimS2 = builder.Build( - "S2", pir::ShapedTypeInterface::kDynamic, false, false, true, true); - - pir::Attribute attrS0 = pir::StrAttribute::get(ctx, "S0"); - pir::Attribute attrS1 = pir::StrAttribute::get(ctx, "S1"); - pir::Attribute attrS2 = pir::StrAttribute::get(ctx, "S2"); - pir::Attribute attrC2 = pir::StrAttribute::get(ctx, "C2"); - - auto attrOp1 = pir::ArrayAttribute::get(ctx, {attrS0, attrC2}); - auto attrOp2 = pir::ArrayAttribute::get(ctx, {attrS1, attrC2}); - auto attrOp3 = pir::ArrayAttribute::get(ctx, {attrC2, attrC2}); - auto attrOp4 = pir::ArrayAttribute::get(ctx, {attrC2, attrC2}); - auto attrOp5 = pir::ArrayAttribute::get(ctx, {attrS2}); - - tieShapeOp1->set_attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), attrOp1); - tieShapeOp2->set_attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), attrOp2); - tieShapeOp3->set_attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), attrOp3); - tieShapeOp4->set_attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), attrOp4); - tieShapeOp5->set_attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), attrOp5); - - pir::ShapeConstraintIRAnalysis shapeAnalysis(program.module_op()); - EXPECT_TRUE(shapeAnalysis.IsShapeEqual(value3, value4)); - EXPECT_FALSE(shapeAnalysis.IsShapeEqual(value1, value2)); - EXPECT_FALSE(shapeAnalysis.IsShapeEqual(value1, value3)); - EXPECT_FALSE(shapeAnalysis.IsShapeEqual(value1, value5)); - EXPECT_FALSE(shapeAnalysis.IsShapeEqual(value3, value5)); - EXPECT_TRUE(shapeAnalysis.IsProductEqual(value1, {1}, value3, {0})); - EXPECT_TRUE(shapeAnalysis.IsSameNumElements(value4, value3)); - - shapeAnalysis.symbolicDimMgr().MapSymbolicDimEqual(symDimS0, symDimS1); - shapeAnalysis.symbolicDimMgr().MapSymbolicDimEqual(symDimS0, symDimS2); - - EXPECT_TRUE(shapeAnalysis.IsShapeEqual(value1, value2)); - EXPECT_FALSE(shapeAnalysis.IsShapeEqual(value1, value5)); -} - -TEST(shape_op, tensor_dim) { - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::Program program(ctx); - ctx->GetOrRegisterDialect(); - pir::Builder builder = pir::Builder(ctx, program.block()); - - pir::Operation *op = CreateDenseTensorOp( - ctx, {pir::ShapedTypeInterface::kDynamic, 2}, {"op_attr"}, {"op_name"}); - pir::OpResult resDenseTensorValue = op->result(0); - - pir::dialect::TensorDimOp tensorDimOp0 = - builder.Build(resDenseTensorValue, 0); - pir::OpResult res0 = 
tensorDimOp0.out();
-
-  pir::OpResult indexValue =
-      builder
-          .Build<pir::ConstantOp>(
-              pir::Int64Attribute::get(pir::IrContext::Instance(), 1),
-              pir::IndexType::get(pir::IrContext::Instance()))
-          ->result(0);
-  pir::dialect::TensorDimOp tensorDimOp1 =
-      builder.Build<pir::dialect::TensorDimOp>(resDenseTensorValue, indexValue);
-  pir::OpResult res1 = tensorDimOp1.out();
-
-  EXPECT_EQ(res0.type(), pir::IndexType::get(ctx));
-  EXPECT_EQ(res1.type(), pir::IndexType::get(ctx));
-  EXPECT_EQ(tensorDimOp0.source(), resDenseTensorValue);
-  EXPECT_EQ(tensorDimOp1.source(), resDenseTensorValue);
-  EXPECT_EQ(tensorDimOp1.index(), indexValue);
-}