Skip to content

MLIR 源码分析:remove-shape-constraints #7

@xtyi

Description

@xtyi

Pass 示例

// Input #1: a single shape.cstr_broadcastable constraint whose witness guards
// the shape.assuming region.
func.func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
  %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
  %1 = shape.assuming %0 -> index {
    %2 = "test.source"() : () -> (index)
    shape.assuming_yield %2 : index
  }
  return %1 : index
}
// Output #1 (presumably after -remove-shape-constraints on input #1): the
// cstr_broadcastable has become `shape.const_witness true`; the assuming
// region itself is unchanged.
func.func @f(%arg0: !shape.shape, %arg1: !shape.shape) -> index {
  %0 = shape.const_witness true
  %1 = shape.assuming %0 -> (index) {
    %2 = "test.source"() : () -> index
    shape.assuming_yield %2 : index
  }
  return %1 : index
}
// Input #2: same shape as input #1, but guarded by a shape.cstr_eq constraint.
func.func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
  %0 = shape.cstr_eq %arg0, %arg1 : !shape.shape, !shape.shape
  %1 = shape.assuming %0 -> index {
    %2 = "test.source"() : () -> (index)
    shape.assuming_yield %2 : index
  }
  return %1 : index
}
// Output #2 (presumably after -remove-shape-constraints on input #2): the
// cstr_eq has become `shape.const_witness true`.
func.func @f(%arg0: !shape.shape, %arg1: !shape.shape) -> index {
  %0 = shape.const_witness true
  %1 = shape.assuming %0 -> (index) {
    %2 = "test.source"() : () -> index
    shape.assuming_yield %2 : index
  }
  return %1 : index
}
// Input #3: both constraint kinds plus a shape.assuming_all that combines
// them. Note that the assuming op consumes only %0, so %2 has no users.
func.func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
  %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
  %1 = shape.cstr_eq %arg0, %arg1 : !shape.shape, !shape.shape
  %2 = shape.assuming_all %0, %1
  %3 = shape.assuming %0 -> index {
    %4 = "test.source"() : () -> (index)
    shape.assuming_yield %4 : index
  }
  return %3 : index
}
// Output #3 (presumably after -remove-shape-constraints on input #3): only a
// single true witness remains. The witness that replaced cstr_eq and the
// shape.assuming_all built from it are gone. Presumably they were folded or
// erased as dead ops by applyPatternsAndFoldGreedily, since nothing uses them.
func.func @f(%arg0: !shape.shape, %arg1: !shape.shape) -> index {
  %0 = shape.const_witness true
  %1 = shape.assuming %0 -> (index) {
    %2 = "test.source"() : () -> index
    shape.assuming_yield %2 : index
  }
  return %1 : index
}

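可以看到,该 Pass 把 shape.cstr_broadcastable 和 shape.cstr_eq 这两种约束指令替换为 shape.const_witness true(并不是所有 shape.cstr_* 指令)。在第三个例子中,替换后不再被使用的 witness 和 shape.assuming_all 也被 greedy rewrite driver 折叠/删除了,所以最终只剩下一个 shape.const_witness true。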

// Function-level pass that drops shape constraints by replacing them with
// always-true witnesses. Only cstr_broadcastable and cstr_eq are handled; the
// summary previously claimed "all cstr_ ops", which overstated the behavior of
// populateRemoveShapeConstraintsPatterns.
def RemoveShapeConstraints : Pass<"remove-shape-constraints", "func::FuncOp"> {
  let summary = "Replace cstr_broadcastable and cstr_eq ops with a true witness";
  let description = [{
    Replaces every `shape.cstr_broadcastable` and `shape.cstr_eq` operation in
    the function with `shape.const_witness true`, so the corresponding
    constraints are treated as always satisfied. Other `shape.cstr_*`
    operations are left untouched.
  }];
  let constructor = "mlir::createRemoveShapeConstraintsPass()";
}
class RemoveShapeConstraintsPass
    : public impl::RemoveShapeConstraintsBase<RemoveShapeConstraintsPass> {

  /// Runs the constraint-removal patterns from
  /// populateRemoveShapeConstraintsPatterns over the current operation using
  /// the greedy pattern rewrite driver.
  void runOnOperation() override {
    // Gather the rewrites that turn shape constraints into true witnesses.
    RewritePatternSet cstrPatterns(&getContext());
    populateRemoveShapeConstraintsPatterns(cstrPatterns);

    // The LogicalResult returned by the greedy driver is deliberately
    // discarded, so the pass never signals failure from this call.
    auto target = getOperation();
    (void)applyPatternsAndFoldGreedily(target, std::move(cstrPatterns));
  }
};

/// Registers the patterns that replace `shape.cstr_broadcastable` and
/// `shape.cstr_eq` with `shape.const_witness true`.
///
/// @param patterns The set to append to; its context is used to construct
///                 the patterns.
void mlir::populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<RemoveCstrBroadcastableOp>(context);
  patterns.add<RemoveCstrEqOp>(context);
}

可以看到,该 Pass 只注册了两个 Pattern:RemoveCstrBroadcastableOp 和 RemoveCstrEqOp。

/// Rewrites every `shape.cstr_broadcastable` into `shape.const_witness true`.
///
/// The pattern has no match conditions: it never inspects the operands and
/// always reports success, so the constraint is dropped unconditionally.
class RemoveCstrBroadcastableOp
    : public OpRewritePattern<shape::CstrBroadcastableOp> {
public:
  using OpRewritePattern<shape::CstrBroadcastableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp cstrOp,
                                PatternRewriter &rewriter) const override {
    // Swap the constraint for a witness that always passes.
    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(cstrOp, true);
    return success();
  }
};

/// Rewrites every `shape.cstr_eq` into `shape.const_witness true`.
///
/// Like the broadcastable variant, the rewrite is unconditional: the operands
/// are ignored and the pattern always succeeds.
class RemoveCstrEqOp : public OpRewritePattern<shape::CstrEqOp> {
public:
  using OpRewritePattern<shape::CstrEqOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::CstrEqOp cstrOp,
                                PatternRewriter &rewriter) const override {
    // Swap the constraint for a witness that always passes.
    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(cstrOp, true);
    return success();
  }
};

可以看到,这两个 Pattern 无条件地把 CstrBroadcastableOp 和 CstrEqOp 替换为 ConstWitnessOp(值为 true)。

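TODO: 这个 Pass 有什么作用? shape.cstr op 和 assuming op 的作用

初步理解(待确认):shape.cstr_* 指令表达一个形状约束,并产生一个 !shape.witness;shape.assuming 以该 witness 为前提执行其内部 region。把约束替换为 shape.const_witness true,相当于假定这些约束一定成立、不再检查它们。之后再做 canonicalize 时,以常量 true 为前提的 shape.assuming 有望被直接展开。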

Metadata

Metadata

Assignees

No one assigned

    Labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions