Skip to content

Commit

Permalink
Add ShapeAnalysisMgr (#59254)
Browse files Browse the repository at this point in the history
* ShapeAnalysisMgr

* singleton

* UT

* Werror=reorder
  • Loading branch information
zhangbopd authored Nov 27, 2023
1 parent def8cbd commit f4be32e
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 5 deletions.
3 changes: 2 additions & 1 deletion paddle/pir/core/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ Operation::Operation(const AttributeMap &attributes,
num_results_(num_results),
num_operands_(num_operands),
num_regions_(num_regions),
num_successors_(num_successors) {}
num_successors_(num_successors),
id_(GenerateId()) {}

///
/// \brief op ouput related public interfaces implementation
Expand Down
8 changes: 8 additions & 0 deletions paddle/pir/core/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ class IR_API alignas(8) Operation final

void Verify();

uint64_t id() { return id_; }

private:
DISABLE_COPY_AND_ASSIGN(Operation);
Operation(const AttributeMap &attribute,
Expand Down Expand Up @@ -219,10 +221,16 @@ class IR_API alignas(8) Operation final

OpInfo info_;

static uint64_t GenerateId() {
static std::atomic<std::uint64_t> uid{0};
return ++uid;
}

const uint32_t num_results_ = 0;
const uint32_t num_operands_ = 0;
const uint32_t num_regions_ = 0;
const uint32_t num_successors_ = 0;
const uint64_t id_ = 0;

detail::BlockOperandImpl *block_operands_{nullptr};
Region *regions_{nullptr};
Expand Down
18 changes: 18 additions & 0 deletions paddle/pir/dialect/shape/utils/shape_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,22 @@ bool ShapeConstraintIRAnalysis::IsProductEqual(Value lhs,
return mgr_.IsSymbolicDimProductEqual(lhs_prod, rhs_prod);
}

ShapeAnalysisManager& ShapeAnalysisManager::Instance() {
static ShapeAnalysisManager instance;
return instance;
}

ShapeConstraintIRAnalysis& ShapeAnalysisManager::Get(pir::Program* program) {
auto it = tables_.find(program->module_op().operation()->id());

if (it == tables_.end()) {
it = tables_
.emplace(program->module_op().operation()->id(),
ShapeConstraintIRAnalysis(program->module_op()))
.first;
}

return it->second;
}

} // namespace pir
17 changes: 13 additions & 4 deletions paddle/pir/dialect/shape/utils/shape_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
namespace pir {

// Helper class to query and manipulate shape constraint IR on buffer level.
class ShapeAnalysis {
class IR_API ShapeAnalysis {
public:
virtual ~ShapeAnalysis() = default;

Expand Down Expand Up @@ -50,11 +50,10 @@ class ShapeAnalysis {

// A subclass to impement `ShapeAnalysis` on buffer level.
// The implementation is based on shape constraint ir.
class ShapeConstraintIRAnalysis : public ShapeAnalysis {
class IR_API ShapeConstraintIRAnalysis : public ShapeAnalysis {
public:
explicit ShapeConstraintIRAnalysis(ModuleOp m);

// auto-save updated shape constriant ir when destroying.
// Auto-save updated shape constriant ir when destroying.
~ShapeConstraintIRAnalysis();

// Returns the `SymbolicDimMgr` this object holds.
Expand All @@ -80,4 +79,14 @@ class ShapeConstraintIRAnalysis : public ShapeAnalysis {
value_to_sym_dims_;
};

class IR_API ShapeAnalysisManager {
public:
static ShapeAnalysisManager& Instance();
ShapeConstraintIRAnalysis& Get(pir::Program* program);

private:
ShapeAnalysisManager() {}
std::unordered_map<uint64_t, ShapeConstraintIRAnalysis> tables_;
};

} // namespace pir
96 changes: 96 additions & 0 deletions test/cpp/pir/shape_dialect/shape_struct_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,3 +417,99 @@ TEST(shape_struct_test, shape_analysis) {
EXPECT_TRUE(shape_analysis.IsShapeEqual(value1, value2));
EXPECT_FALSE(shape_analysis.IsShapeEqual(value1, value5));
}

TEST(shape_struct_test, shape_analysis_manager) {
pir::IrContext *ctx = pir::IrContext::Instance();
pir::Program program(ctx);
ctx->GetOrRegisterDialect<pir::shape::ShapeDialect>();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
::pir::Builder builder = ::pir::Builder(ctx, program.block());
pir::shape::FuncOp func_op = builder.Build<pir::shape::FuncOp>();

phi::DDim dims_D_2 = {-1, 2};
phi::DDim dims_2_2 = {2, 2};
phi::DDim dims_D = {-1};

// same shape with dynamic: value1 == value2
auto op1 =
test::CreateDenseTensorOp(ctx, dims_D_2, {"op1_attr"}, {"op1_name"});
auto op2 =
test::CreateDenseTensorOp(ctx, dims_D_2, {"op2_attr"}, {"op2_name"});
pir::OpResult value1 = op1->result(0);
pir::OpResult value2 = op2->result(0);

// same shape with static: value3 == value4
auto op3 =
test::CreateDenseTensorOp(ctx, dims_2_2, {"op3_attr"}, {"op3_name"});
auto op4 =
test::CreateDenseTensorOp(ctx, dims_2_2, {"op4_attr"}, {"op4_name"});
pir::OpResult value3 = op3->result(0);
pir::OpResult value4 = op4->result(0);

// one dimension with dynamic: value5 != value1 != value3
auto op5 = test::CreateDenseTensorOp(ctx, dims_D, {"op5_attr"}, {"op5_name"});
pir::OpResult value5 = op5->result(0);

pir::shape::TieShapeOp tie_shape_op1 =
builder.Build<pir::shape::TieShapeOp>(value1);
pir::shape::TieShapeOp tie_shape_op2 =
builder.Build<pir::shape::TieShapeOp>(value2);
pir::shape::TieShapeOp tie_shape_op3 =
builder.Build<pir::shape::TieShapeOp>(value3);
pir::shape::TieShapeOp tie_shape_op4 =
builder.Build<pir::shape::TieShapeOp>(value4);
pir::shape::TieShapeOp tie_shape_op5 =
builder.Build<pir::shape::TieShapeOp>(value5);

builder.SetInsertionPointToEnd(func_op.block());
builder.Build<pir::shape::SymbolicDimOp>("C2", 2, true, false, true, true);
pir::shape::SymbolicDimOp sym_dim_s0 =
builder.Build<pir::shape::SymbolicDimOp>(
"S0", pir::ShapedTypeInterface::kDynamic, false, false, true, true);
pir::shape::SymbolicDimOp sym_dim_s1 =
builder.Build<pir::shape::SymbolicDimOp>(
"S1", pir::ShapedTypeInterface::kDynamic, false, false, true, true);
pir::shape::SymbolicDimOp sym_dim_s2 =
builder.Build<pir::shape::SymbolicDimOp>(
"S2", pir::ShapedTypeInterface::kDynamic, false, false, true, true);

pir::Attribute attr_s0 = pir::StrAttribute::get(ctx, "S0");
pir::Attribute attr_s1 = pir::StrAttribute::get(ctx, "S1");
pir::Attribute attr_s2 = pir::StrAttribute::get(ctx, "S2");
pir::Attribute attr_c2 = pir::StrAttribute::get(ctx, "C2");

auto attr_op1 = pir::ArrayAttribute::get(ctx, {attr_s0, attr_c2});
auto attr_op2 = pir::ArrayAttribute::get(ctx, {attr_s1, attr_c2});
auto attr_op3 = pir::ArrayAttribute::get(ctx, {attr_c2, attr_c2});
auto attr_op4 = pir::ArrayAttribute::get(ctx, {attr_c2, attr_c2});
auto attr_op5 = pir::ArrayAttribute::get(ctx, {attr_s2});

tie_shape_op1->set_attribute(
pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op1);
tie_shape_op2->set_attribute(
pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op2);
tie_shape_op3->set_attribute(
pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op3);
tie_shape_op4->set_attribute(
pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op4);
tie_shape_op5->set_attribute(
pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op5);

auto shape_analysis_mgr = pir::ShapeAnalysisManager::Instance();
pir::ShapeConstraintIRAnalysis &shape_analysis =
shape_analysis_mgr.Get(&program);

EXPECT_TRUE(shape_analysis.IsShapeEqual(value3, value4));
EXPECT_FALSE(shape_analysis.IsShapeEqual(value1, value2));
EXPECT_FALSE(shape_analysis.IsShapeEqual(value1, value3));
EXPECT_FALSE(shape_analysis.IsShapeEqual(value1, value5));
EXPECT_FALSE(shape_analysis.IsShapeEqual(value3, value5));
EXPECT_TRUE(shape_analysis.IsProductEqual(value1, {1}, value3, {0}));
EXPECT_TRUE(shape_analysis.IsSameNumElements(value4, value3));

shape_analysis.symbolicDimMgr().MapSymbolicDimEqual(sym_dim_s0, sym_dim_s1);
shape_analysis.symbolicDimMgr().MapSymbolicDimEqual(sym_dim_s0, sym_dim_s2);

EXPECT_TRUE(shape_analysis.IsShapeEqual(value1, value2));
EXPECT_FALSE(shape_analysis.IsShapeEqual(value1, value5));
}

0 comments on commit f4be32e

Please sign in to comment.