Skip to content

Commit

Permalink
[CINN]update dyshape workflow (#62101)
Browse files Browse the repository at this point in the history
* update dyshape workflow

* update

* polish code

* polish code

* fix compiler bug
  • Loading branch information
phlrain authored Feb 29, 2024
1 parent 4865fed commit 4448d45
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ void ApplyCinnPreprocessPass(

pass_manager->AddPass(
cinn::dialect::ir::CreateMoveGenerateShapeOpsToProloguePass());
pass_manager->AddPass(cinn::dialect::ir::CreateDivideGroupOpToFusionOpPass());
pass_manager->AddPass(cinn::dialect::ir::CreateDynamicReshapeOpPass());
pass_manager->AddPass(cinn::dialect::ir::CreateReplaceDynamicExpandOpPass());
pass_manager->AddPass(cinn::dialect::ir::CreateDivideGroupOpToFusionOpPass());
pass_manager->AddPass(pir::CreateDeadCodeEliminationPass());

pass_manager->Run(program);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class DynamicReshapeOpPass : public pir::Pass {
for (uint32_t i = 0; i < op->num_regions(); ++i) {
for (auto& block : op->region(i)) {
for (auto& op : block) {
if (op.isa<cinn::dialect::FusionOp>()) {
if (op.isa<cinn::dialect::GroupOp>()) {
auto [_, num_rewrites] =
pir::ApplyPatternsGreedily(&op, patterns_, cfg);
AddStatistics(num_rewrites);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,28 @@ class DynamicExpandOpPattern
for (size_t i = 0; i < x_rank; ++i) {
broadcast_axes[i] = i + index_gap;
}
std::vector<int64_t> out_shape(out_rank, -1);

pir::ShapeConstraintIRAnalysis& shape_analysis =
pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());

const auto& UpdateOutputShapeByDimExpr = [&]() -> std::vector<int64_t> {
std::vector<int64_t> out_shape(out_rank, -1);
if (shape_analysis.HasShapeOrDataForValue(op->result(0))) {
VLOG(3) << "found shape dialect";
auto shape_info =
shape_analysis.GetShapeOrDataForValue(op->result(0)).shape();

for (size_t i = 0; i < shape_info.size(); ++i) {
if (shape_info[i].isa<int64_t>()) {
out_shape[i] = shape_info[i].Get<int64_t>();
}
}
}
return out_shape;
};

auto out_shape = UpdateOutputShapeByDimExpr();

return rewriter.Build<cinn::dialect::BroadcastOp>(
op->operand_source(0), broadcast_axes, out_shape);
}();
Expand Down Expand Up @@ -91,7 +112,7 @@ class ReplaceDynamicExpandOpPass : public pir::Pass {
for (uint32_t i = 0; i < op->num_regions(); ++i) {
for (auto& block : op->region(i)) {
for (auto& op : block) {
if (op.isa<cinn::dialect::FusionOp>()) {
if (op.isa<cinn::dialect::GroupOp>()) {
const auto& [_, num_rewrites] =
pir::ApplyPatternsGreedily(&op, patterns_, cfg);
AddStatistics(num_rewrites);
Expand Down

0 comments on commit 4448d45

Please sign in to comment.