-
Notifications
You must be signed in to change notification settings - Fork 0
Open
Labels
Description
`remove-shape-constraints` Pass 示例（每组依次为输入 IR 和经过该 Pass 处理后的输出 IR）
// Example 1 (input): a shape.cstr_broadcastable witness guards an assuming region.
func.func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
  %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
  %1 = shape.assuming %0 -> index {
    %2 = "test.source"() : () -> (index)
    shape.assuming_yield %2 : index
  }
  return %1 : index
}
// Example 1 (output): the constraint becomes an always-true witness; the
// shape.assuming region itself is kept.
func.func @f(%arg0: !shape.shape, %arg1: !shape.shape) -> index {
  %0 = shape.const_witness true
  %1 = shape.assuming %0 -> (index) {
    %2 = "test.source"() : () -> index
    shape.assuming_yield %2 : index
  }
  return %1 : index
}
// Example 2 (input): same structure, but the witness comes from shape.cstr_eq.
func.func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
  %0 = shape.cstr_eq %arg0, %arg1 : !shape.shape, !shape.shape
  %1 = shape.assuming %0 -> index {
    %2 = "test.source"() : () -> (index)
    shape.assuming_yield %2 : index
  }
  return %1 : index
}
// Example 2 (output): identical to the output of Example 1.
func.func @f(%arg0: !shape.shape, %arg1: !shape.shape) -> index {
  %0 = shape.const_witness true
  %1 = shape.assuming %0 -> (index) {
    %2 = "test.source"() : () -> index
    shape.assuming_yield %2 : index
  }
  return %1 : index
}
// Example 3 (input): both constraint kinds, combined with shape.assuming_all.
// NOTE: shape.assuming consumes %0 directly, so %2 (the assuming_all result)
// has no users.
func.func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
  %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
  %1 = shape.cstr_eq %arg0, %arg1 : !shape.shape, !shape.shape
  %2 = shape.assuming_all %0, %1
  %3 = shape.assuming %0 -> index {
    %4 = "test.source"() : () -> (index)
    shape.assuming_yield %4 : index
  }
  return %3 : index
}
// Example 3 (output): only the witness feeding shape.assuming remains.
// Presumably the unused witness and shape.assuming_all are erased as dead
// code by the greedy rewrite driver (applyPatternsAndFoldGreedily) — confirm.
func.func @f(%arg0: !shape.shape, %arg1: !shape.shape) -> index {
  %0 = shape.const_witness true
  %1 = shape.assuming %0 -> (index) {
    %2 = "test.source"() : () -> index
    shape.assuming_yield %2 : index
  }
  return %1 : index
}
可以看到,该 Pass 做的基本上是把 shape.cstr_ 相关指令转换为 shape.const_witness true
// Declarative definition of the shape-constraint removal pass.
// - Pass argument (command-line name): "remove-shape-constraints".
// - Anchor operation: func::FuncOp, so it runs once per function.
// - The C++ instance is built by mlir::createRemoveShapeConstraintsPass().
// NOTE(review): the summary says "all cstr_ ops", but
// populateRemoveShapeConstraintsPatterns only registers patterns for
// shape.cstr_broadcastable and shape.cstr_eq. Confirm whether any other cstr_
// ops are intentionally left untouched.
def RemoveShapeConstraints : Pass<"remove-shape-constraints", "func::FuncOp"> {
  let summary = "Replace all cstr_ ops with a true witness";
  let constructor = "mlir::createRemoveShapeConstraintsPass()";
}
/// Pass that discards shape-constraint checks: every op matched by the
/// patterns from populateRemoveShapeConstraintsPatterns is rewritten into a
/// `shape.const_witness true`, using the greedy pattern driver.
class RemoveShapeConstraintsPass
    : public impl::RemoveShapeConstraintsBase<RemoveShapeConstraintsPass> {
  void runOnOperation() override {
    RewritePatternSet cstrRemovalPatterns(&getContext());
    populateRemoveShapeConstraintsPatterns(cstrRemovalPatterns);

    // The greedy driver's LogicalResult only reports whether rewriting reached
    // a fixpoint. This pass does not treat non-convergence as a failure, so
    // the result is deliberately discarded.
    (void)applyPatternsAndFoldGreedily(getOperation(),
                                       std::move(cstrRemovalPatterns));
  }
};
/// Adds the remove-shape-constraints rewrite patterns to `patterns`: one for
/// shape.cstr_broadcastable and one for shape.cstr_eq. Both patterns replace
/// the constraint with `shape.const_witness true`.
void mlir::populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<RemoveCstrBroadcastableOp>(context);
  patterns.add<RemoveCstrEqOp>(context);
}
可以看到主要就是两个 Pattern
/// Rewrites every `shape.cstr_broadcastable` into `shape.const_witness true`.
/// The operands are never inspected, so broadcastability is simply assumed
/// to hold.
class RemoveCstrBroadcastableOp
    : public OpRewritePattern<shape::CstrBroadcastableOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp cstrOp,
                                PatternRewriter &rewriter) const override {
    Operation *toReplace = cstrOp.getOperation();
    // `true`: the replacement witness always passes.
    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(toReplace, true);
    return success();
  }
};
/// Rewrites every `shape.cstr_eq` into `shape.const_witness true`.
/// Shape equality is assumed unconditionally; the operands are not examined.
class RemoveCstrEqOp : public OpRewritePattern<shape::CstrEqOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::CstrEqOp eqOp,
                                PatternRewriter &rewriter) const override {
    // `true`: the replacement witness always passes.
    constexpr bool kAlwaysPassing = true;
    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(eqOp.getOperation(),
                                                       kAlwaysPassing);
    return success();
  }
};
可以看到就是把 CstrBroadcastableOp 和 CstrEqOp 换成 ConstWitnessOp
TODO：这个 Pass 有什么作用？`shape.cstr_*` op 和 `shape.assuming` op 分别起什么作用？
Reactions are currently unavailable