Skip to content

Commit

Permalink
[Dynamic Shape]Fix SubstituteDimExprBasedOnConstraintsPass invalid bug
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahy0825 committed Mar 8, 2024
1 parent c8cd35d commit b4d2c6c
Showing 1 changed file with 36 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,19 @@ namespace ir {
namespace {

template <typename DoEachT>
void VisitEachOp(pir::Operation* op, const DoEachT& DoEach) {
for (uint32_t i = 0; i < op->num_regions(); i++) {
for (pir::Block& block : op->region(i)) {
for (pir::Operation& sub_op : block) {
DoEach(sub_op);
if (sub_op.num_regions() > 0) {
VisitEachOp(&sub_op, DoEach);
}
}
}
void VisitEachOp(cinn::dialect::GroupOp op, const DoEachT& DoEach) {
for (pir::Operation* sub_op : op.GetOperators()) {
DoEach(sub_op);
}
}

template <typename DoEachT>
void VisitEachValue(const pir::Operation& op, const DoEachT& DoEach) {
for (std::size_t i = 0; i < op.num_operands(); ++i) {
DoEach(op.operand_source(i));
void VisitEachValue(const pir::Operation* op, const DoEachT& DoEach) {
for (std::size_t i = 0; i < op->num_operands(); ++i) {
DoEach(op->operand_source(i));
}
for (std::size_t i = 0; i < op.num_results(); ++i) {
DoEach(op.result(i));
for (std::size_t i = 0; i < op->num_results(); ++i) {
DoEach(op->result(i));
}
}

Expand Down Expand Up @@ -99,6 +92,22 @@ symbol::ShapeOrDataDimExprs SubstituteShapeOrData(
return std::visit(lambdas, shape_or_data.variant());
}

int GetDimExprPriority(const symbol::DimExpr& dim_expr) {
return std::visit(
symbol::Overloaded{
[&](std::int64_t) { return 0; },
[&](const std::string&) { return 1; },
[&](const symbol::Negative<symbol::DimExpr>&) { return 2; },
[&](const symbol::Reciprocal<symbol::DimExpr>&) { return 2; },
[&](const symbol::Add<symbol::DimExpr>&) { return 2; },
[&](const symbol::Mul<symbol::DimExpr>&) { return 2; },
[&](const symbol::Max<symbol::DimExpr>&) { return 2; },
[&](const symbol::Min<symbol::DimExpr>&) { return 2; },
[&](const symbol::Broadcast<symbol::DimExpr>&) { return 2; },
},
dim_expr.variant());
}

std::unordered_map<symbol::DimExpr, symbol::DimExpr> GetDimExprSubstitution(
pir::ShapeConstraintIRAnalysis* shape_analysis) {
const std::vector<symbol::DimExprConstraint>& dim_expr_constraints =
Expand All @@ -123,9 +132,8 @@ std::unordered_map<symbol::DimExpr, symbol::DimExpr> GetDimExprSubstitution(
CHECK(!dim_expr_cluster.empty());
auto dim_expr_root = dim_expr_cluster[0];
for (const auto& dim_expr : dim_expr_cluster) {
if (std::holds_alternative<std::int64_t>(dim_expr)) {
if (GetDimExprPriority(dim_expr) < GetDimExprPriority(dim_expr_root)) {
dim_expr_root = dim_expr;
break;
}
}
for (const auto& dim_expr : dim_expr_cluster) {
Expand All @@ -137,40 +145,39 @@ std::unordered_map<symbol::DimExpr, symbol::DimExpr> GetDimExprSubstitution(
return substitution_pattern;
}

void SubstituteDimExprBasedOnConstraints(pir::Operation* module_op) {
void SubstituteDimExprBasedOnConstraints(pir::Operation* op) {
VLOG(4) << "SubstituteDimExprBasedOnConstraints start";
auto group_op = op->dyn_cast<cinn::dialect::GroupOp>();
pir::ShapeConstraintIRAnalysis* shape_analysis =
&pir::ShapeAnalysisManager::Instance().Get(
module_op->dyn_cast<pir::ModuleOp>().program());
&pir::ShapeAnalysisManager::Instance().Get(group_op->GetParentProgram());
const std::unordered_map<symbol::DimExpr, symbol::DimExpr>&
substitution_pattern = GetDimExprSubstitution(shape_analysis);

VisitEachOp(module_op, [&](pir::Operation& op) {
VisitEachOp(group_op, [&](pir::Operation* op) {
VisitEachValue(op, [&](pir::Value value) {
if (!shape_analysis->HasShapeOrDataForValue(value)) {
VLOG(4) << "Can not find ShapeOrData for value of op(" << op.name()
VLOG(4) << "Can not find ShapeOrData for value of op(" << op->name()
<< ") in shape_analysis";
} else {
const symbol::ShapeOrDataDimExprs& origin_shape_or_data =
shape_analysis->GetShapeOrDataForValue(value);
VLOG(8) << op.name()
VLOG(8) << op->name()
<< " origin_shape_or_data: " << origin_shape_or_data;
const symbol::ShapeOrDataDimExprs& substituted_shape_or_data =
SubstituteShapeOrData(origin_shape_or_data, substitution_pattern);
VLOG(8) << op.name()
VLOG(8) << op->name()
<< " substituted_shape_or_data: " << substituted_shape_or_data;
shape_analysis->SetShapeOrDataForValue(value,
substituted_shape_or_data);
}
});
if (op.num_results() > 0) {
if (op->num_results() > 0) {
pir::shape::SetShapeAttrForOp(
&op, shape_analysis->GetShapeOrDataForValue(op.result(0)));
op, shape_analysis->GetShapeOrDataForValue(op->result(0)));
} else {
pir::shape::SetShapeAttrForOp(
&op, shape_analysis->GetShapeOrDataForValue(op.operand_source(0)));
op, shape_analysis->GetShapeOrDataForValue(op->operand_source(0)));
}
// TODO(JiaWenxuan): substitute the attribute "sym_shape_str" of the op
});
VLOG(4) << "SubstituteDimExprBasedOnConstraints end";
}
Expand All @@ -185,7 +192,7 @@ class SubstituteDimExprBasedOnConstraintsPass : public pir::Pass {
}

bool CanApplyOn(pir::Operation* op) const override {
return op->isa<pir::ModuleOp>() && op->num_regions() > 0;
return op->isa<cinn::dialect::GroupOp>() && op->num_regions() > 0;
}
};

Expand Down

0 comments on commit b4d2c6c

Please sign in to comment.