Skip to content

Commit dc74d2f

Browse files
authored
[MLIR][NFC] Retire let constructor for Shape and MLProgram (#128869)
`let constructor` is legacy (do not use in tree!) since the table gen backend emits most of the glue logic to build a pass. This PR retires the td method for Shape and MLProgram
1 parent 7521207 commit dc74d2f

File tree

8 files changed

+16
-47
lines changed

8 files changed

+16
-47
lines changed

mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h

-2
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@ namespace ml_program {
2323
// Registration
2424
//===----------------------------------------------------------------------===//
2525

26-
std::unique_ptr<OperationPass<ModuleOp>> createMLProgramPipelineGlobalsPass();
27-
2826
/// Generate the code for registering passes.
2927
#define GEN_PASS_REGISTRATION
3028
#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"

mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111

1212
include "mlir/Pass/PassBase.td"
1313

14-
def MLProgramPipelineGlobals : Pass<"mlprogram-pipeline-globals", "ModuleOp"> {
14+
def MLProgramPipelineGlobalsPass
15+
: Pass<"mlprogram-pipeline-globals", "ModuleOp"> {
1516
let summary = "Optimize `ml_program` global operations for read and store";
1617
let description = [{
1718
`ml_program`'s load and store operations can be optimized for
@@ -21,7 +22,6 @@ def MLProgramPipelineGlobals : Pass<"mlprogram-pipeline-globals", "ModuleOp"> {
2122
The pass is designed to handle both nested regions and function calls
2223
safely.
2324
}];
24-
let constructor = "mlir::ml_program::createMLProgramPipelineGlobalsPass()";
2525
}
2626

2727
#endif // MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES

mlir/include/mlir/Dialect/Shape/Transforms/Passes.h

-10
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,6 @@ namespace mlir {
3030
#define GEN_PASS_DECL
3131
#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
3232

33-
/// Creates an instance of the ShapeToShapeLowering pass that legalizes Shape
34-
/// dialect to be convertible to Arith. For example, `shape.num_elements` get
35-
/// transformed to `shape.reduce`, which can be lowered to SCF and Arith.
36-
std::unique_ptr<Pass> createShapeToShapeLowering();
37-
3833
/// Collects a set of patterns to rewrite ops within the Shape dialect.
3934
void populateShapeRewritePatterns(RewritePatternSet &patterns);
4035

@@ -45,11 +40,6 @@ void populateShapeRewritePatterns(RewritePatternSet &patterns);
4540
//
4641
// After this pass, no cstr_ operations exist.
4742
void populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns);
48-
std::unique_ptr<OperationPass<func::FuncOp>> createRemoveShapeConstraintsPass();
49-
50-
/// Outline the shape computation part by adding shape.func and populate
51-
/// conrresponding mapping infomation into ShapeMappingAnalysis.
52-
std::unique_ptr<OperationPass<ModuleOp>> createOutlineShapeComputationPass();
5343

5444
//===----------------------------------------------------------------------===//
5545
// Registration

mlir/include/mlir/Dialect/Shape/Transforms/Passes.td

+5-6
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111

1212
include "mlir/Pass/PassBase.td"
1313

14-
def OutlineShapeComputation : Pass<"outline-shape-computation", "ModuleOp"> {
14+
def OutlineShapeComputationPass
15+
: Pass<"outline-shape-computation", "ModuleOp"> {
1516
let summary = "Using shape.func to preserve shape computation";
1617
let description = [{
1718
This pass outlines the shape computation part in high level IR by adding
@@ -89,18 +90,16 @@ def OutlineShapeComputation : Pass<"outline-shape-computation", "ModuleOp"> {
8990
// - Shape for: %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32> :: @shape_cal_1(<block argument> of type 'tensor<?x4x?xf32>' at index: 0)
9091
```
9192
}];
92-
let constructor = "mlir::createOutlineShapeComputationPass()";
9393
let dependentDialects = ["shape::ShapeDialect"];
9494
}
9595

96-
def RemoveShapeConstraints : Pass<"remove-shape-constraints", "func::FuncOp"> {
96+
def RemoveShapeConstraintsPass
97+
: Pass<"remove-shape-constraints", "func::FuncOp"> {
9798
let summary = "Replace all cstr_ ops with a true witness";
98-
let constructor = "mlir::createRemoveShapeConstraintsPass()";
9999
}
100100

101-
def ShapeToShapeLowering : Pass<"shape-to-shape-lowering", "func::FuncOp"> {
101+
def ShapeToShapeLoweringPass : Pass<"shape-to-shape-lowering", "func::FuncOp"> {
102102
let summary = "Legalize Shape dialect to be convertible to Arith";
103-
let constructor = "mlir::createShapeToShapeLowering()";
104103
}
105104

106105
#endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES

mlir/lib/Dialect/MLProgram/Transforms/PipelineGlobalOps.cpp

+2-7
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616

1717
namespace mlir {
1818
namespace ml_program {
19-
#define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALS
19+
#define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALSPASS
2020
#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
2121

2222
namespace {
2323

2424
class MLProgramPipelineGlobals
25-
: public impl::MLProgramPipelineGlobalsBase<MLProgramPipelineGlobals> {
25+
: public impl::MLProgramPipelineGlobalsPassBase<MLProgramPipelineGlobals> {
2626
public:
2727
void runOnOperation() override;
2828

@@ -224,10 +224,5 @@ void MLProgramPipelineGlobals::runOnOperation() {
224224

225225
} // namespace
226226

227-
std::unique_ptr<OperationPass<mlir::ModuleOp>>
228-
createMLProgramPipelineGlobalsPass() {
229-
return std::make_unique<MLProgramPipelineGlobals>();
230-
}
231-
232227
} // namespace ml_program
233228
} // namespace mlir

mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp

+3-7
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
#include <vector>
2424

2525
namespace mlir {
26-
#define GEN_PASS_DEF_OUTLINESHAPECOMPUTATION
26+
#define GEN_PASS_DEF_OUTLINESHAPECOMPUTATIONPASS
2727
#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
2828
} // namespace mlir
2929

@@ -163,7 +163,8 @@ void constructShapeFunc(
163163
}
164164

165165
struct OutlineShapeComputationPass
166-
: public impl::OutlineShapeComputationBase<OutlineShapeComputationPass> {
166+
: public impl::OutlineShapeComputationPassBase<
167+
OutlineShapeComputationPass> {
167168

168169
void runOnOperation() override;
169170

@@ -324,8 +325,3 @@ bool OutlineShapeComputationPass::calOnlyUsedByWithShapesRecursively(
324325
}
325326

326327
} // namespace
327-
328-
std::unique_ptr<OperationPass<ModuleOp>>
329-
mlir::createOutlineShapeComputationPass() {
330-
return std::make_unique<OutlineShapeComputationPass>();
331-
}

mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp

+2-7
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1515

1616
namespace mlir {
17-
#define GEN_PASS_DEF_REMOVESHAPECONSTRAINTS
17+
#define GEN_PASS_DEF_REMOVESHAPECONSTRAINTSPASS
1818
#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
1919
} // namespace mlir
2020

@@ -47,7 +47,7 @@ class RemoveCstrEqOp : public OpRewritePattern<shape::CstrEqOp> {
4747

4848
/// Removal pass.
4949
class RemoveShapeConstraintsPass
50-
: public impl::RemoveShapeConstraintsBase<RemoveShapeConstraintsPass> {
50+
: public impl::RemoveShapeConstraintsPassBase<RemoveShapeConstraintsPass> {
5151

5252
void runOnOperation() override {
5353
MLIRContext &ctx = getContext();
@@ -65,8 +65,3 @@ void mlir::populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns) {
6565
patterns.add<RemoveCstrBroadcastableOp, RemoveCstrEqOp>(
6666
patterns.getContext());
6767
}
68-
69-
std::unique_ptr<OperationPass<func::FuncOp>>
70-
mlir::createRemoveShapeConstraintsPass() {
71-
return std::make_unique<RemoveShapeConstraintsPass>();
72-
}

mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp

+2-6
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
#include "mlir/Transforms/DialectConversion.h"
1818

1919
namespace mlir {
20-
#define GEN_PASS_DEF_SHAPETOSHAPELOWERING
20+
#define GEN_PASS_DEF_SHAPETOSHAPELOWERINGPASS
2121
#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
2222
} // namespace mlir
2323

@@ -59,7 +59,7 @@ NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
5959

6060
namespace {
6161
struct ShapeToShapeLowering
62-
: public impl::ShapeToShapeLoweringBase<ShapeToShapeLowering> {
62+
: public impl::ShapeToShapeLoweringPassBase<ShapeToShapeLowering> {
6363
void runOnOperation() override;
6464
};
6565
} // namespace
@@ -81,7 +81,3 @@ void ShapeToShapeLowering::runOnOperation() {
8181
void mlir::populateShapeRewritePatterns(RewritePatternSet &patterns) {
8282
patterns.add<NumElementsOpConverter>(patterns.getContext());
8383
}
84-
85-
std::unique_ptr<Pass> mlir::createShapeToShapeLowering() {
86-
return std::make_unique<ShapeToShapeLowering>();
87-
}

0 commit comments

Comments
 (0)