Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dynamic Shape]Fix SubstituteDimExprBasedOnConstraintsPass invalid bug #62570

Merged
merged 3 commits into from
Mar 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "paddle/cinn/common/dim_expr_util.h"
#include "paddle/cinn/common/union_find.h"
#include "paddle/pir/include/dialect/shape/ir/shape_attribute.h"
#include "paddle/pir/include/dialect/shape/utils/dim_expr_simplify.h"

namespace cinn {
namespace dialect {
Expand All @@ -27,26 +28,19 @@ namespace ir {
namespace {

template <typename DoEachT>
void VisitEachOp(pir::Operation* op, const DoEachT& DoEach) {
for (uint32_t i = 0; i < op->num_regions(); i++) {
for (pir::Block& block : op->region(i)) {
for (pir::Operation& sub_op : block) {
DoEach(sub_op);
if (sub_op.num_regions() > 0) {
VisitEachOp(&sub_op, DoEach);
}
}
}
void VisitEachOp(cinn::dialect::GroupOp op, const DoEachT& DoEach) {
for (pir::Operation* sub_op : op.GetOperators()) {
DoEach(sub_op);
}
}

template <typename DoEachT>
void VisitEachValue(const pir::Operation& op, const DoEachT& DoEach) {
for (std::size_t i = 0; i < op.num_operands(); ++i) {
DoEach(op.operand_source(i));
void VisitEachValue(const pir::Operation* op, const DoEachT& DoEach) {
for (std::size_t i = 0; i < op->num_operands(); ++i) {
DoEach(op->operand_source(i));
}
for (std::size_t i = 0; i < op.num_results(); ++i) {
DoEach(op.result(i));
for (std::size_t i = 0; i < op->num_results(); ++i) {
DoEach(op->result(i));
}
}

Expand All @@ -60,8 +54,9 @@ symbol::TensorShapeOrDataDimExprs SubstituteTensorShapeOrData(
substitution_pattern) -> std::vector<symbol::DimExpr> {
std::vector<symbol::DimExpr> substituted_dim_expr{};
for (const symbol::DimExpr& dim_expr : original_dim_expr) {
substituted_dim_expr.push_back(
cinn::common::SubstituteDimExpr(dim_expr, substitution_pattern));
const auto& tmp_dim_expr =
cinn::common::SubstituteDimExpr(dim_expr, substitution_pattern);
substituted_dim_expr.push_back(symbol::SimplifyDimExpr(tmp_dim_expr));
}
return substituted_dim_expr;
};
Expand Down Expand Up @@ -99,6 +94,22 @@ symbol::ShapeOrDataDimExprs SubstituteShapeOrData(
return std::visit(lambdas, shape_or_data.variant());
}

int GetDimExprPriority(const symbol::DimExpr& dim_expr) {
return std::visit(
symbol::Overloaded{
[&](std::int64_t) { return 0; },
[&](const std::string&) { return 1; },
[&](const symbol::Negative<symbol::DimExpr>&) { return 2; },
[&](const symbol::Reciprocal<symbol::DimExpr>&) { return 2; },
[&](const symbol::Add<symbol::DimExpr>&) { return 2; },
[&](const symbol::Mul<symbol::DimExpr>&) { return 2; },
[&](const symbol::Max<symbol::DimExpr>&) { return 2; },
[&](const symbol::Min<symbol::DimExpr>&) { return 2; },
[&](const symbol::Broadcast<symbol::DimExpr>&) { return 2; },
},
dim_expr.variant());
}

std::unordered_map<symbol::DimExpr, symbol::DimExpr> GetDimExprSubstitution(
pir::ShapeConstraintIRAnalysis* shape_analysis) {
const std::vector<symbol::DimExprConstraint>& dim_expr_constraints =
Expand All @@ -123,9 +134,8 @@ std::unordered_map<symbol::DimExpr, symbol::DimExpr> GetDimExprSubstitution(
CHECK(!dim_expr_cluster.empty());
auto dim_expr_root = dim_expr_cluster[0];
for (const auto& dim_expr : dim_expr_cluster) {
if (std::holds_alternative<std::int64_t>(dim_expr)) {
if (GetDimExprPriority(dim_expr) < GetDimExprPriority(dim_expr_root)) {
dim_expr_root = dim_expr;
break;
}
}
for (const auto& dim_expr : dim_expr_cluster) {
Expand All @@ -137,40 +147,39 @@ std::unordered_map<symbol::DimExpr, symbol::DimExpr> GetDimExprSubstitution(
return substitution_pattern;
}

void SubstituteDimExprBasedOnConstraints(pir::Operation* module_op) {
void SubstituteDimExprBasedOnConstraints(pir::Operation* op) {
VLOG(4) << "SubstituteDimExprBasedOnConstraints start";
auto group_op = op->dyn_cast<cinn::dialect::GroupOp>();
pir::ShapeConstraintIRAnalysis* shape_analysis =
&pir::ShapeAnalysisManager::Instance().Get(
module_op->dyn_cast<pir::ModuleOp>().program());
&pir::ShapeAnalysisManager::Instance().Get(group_op->GetParentProgram());
const std::unordered_map<symbol::DimExpr, symbol::DimExpr>&
substitution_pattern = GetDimExprSubstitution(shape_analysis);

VisitEachOp(module_op, [&](pir::Operation& op) {
VisitEachOp(group_op, [&](pir::Operation* op) {
VisitEachValue(op, [&](pir::Value value) {
if (!shape_analysis->HasShapeOrDataForValue(value)) {
VLOG(4) << "Can not find ShapeOrData for value of op(" << op.name()
VLOG(4) << "Can not find ShapeOrData for value of op(" << op->name()
<< ") in shape_analysis";
} else {
const symbol::ShapeOrDataDimExprs& origin_shape_or_data =
shape_analysis->GetShapeOrDataForValue(value);
VLOG(8) << op.name()
VLOG(8) << op->name()
<< " origin_shape_or_data: " << origin_shape_or_data;
const symbol::ShapeOrDataDimExprs& substituted_shape_or_data =
SubstituteShapeOrData(origin_shape_or_data, substitution_pattern);
VLOG(8) << op.name()
VLOG(8) << op->name()
<< " substituted_shape_or_data: " << substituted_shape_or_data;
shape_analysis->SetShapeOrDataForValue(value,
substituted_shape_or_data);
}
});
if (op.num_results() > 0) {
if (op->num_results() > 0) {
pir::shape::SetShapeAttrForOp(
&op, shape_analysis->GetShapeOrDataForValue(op.result(0)));
op, shape_analysis->GetShapeOrDataForValue(op->result(0)));
} else {
pir::shape::SetShapeAttrForOp(
&op, shape_analysis->GetShapeOrDataForValue(op.operand_source(0)));
op, shape_analysis->GetShapeOrDataForValue(op->operand_source(0)));
}
// TODO(JiaWenxuan): substitute the attribute "sym_shape_str" of the op
});
VLOG(4) << "SubstituteDimExprBasedOnConstraints end";
}
Expand All @@ -185,7 +194,7 @@ class SubstituteDimExprBasedOnConstraintsPass : public pir::Pass {
}

bool CanApplyOn(pir::Operation* op) const override {
return op->isa<pir::ModuleOp>() && op->num_regions() > 0;
return op->isa<cinn::dialect::GroupOp>() && op->num_regions() > 0;
}
};

Expand Down