diff --git a/libspu/compiler/passes/rewrite_div_sqrt_patterns.cc b/libspu/compiler/passes/rewrite_div_sqrt_patterns.cc index 9d88ebe1..b8c05534 100644 --- a/libspu/compiler/passes/rewrite_div_sqrt_patterns.cc +++ b/libspu/compiler/passes/rewrite_div_sqrt_patterns.cc @@ -12,9 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include - #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -70,24 +67,28 @@ struct DivRewriter : public OpRewritePattern { // Pattern 2: // y/(k*sqrt(x)) -> y/k*rsqrt(x) if (auto mulOp = denominator.getDefiningOp()) { - auto sqrtOp = mulOp.getRhs().getDefiningOp(); - auto k = mulOp.getLhs(); - if (sqrtOp == nullptr) { - sqrtOp = mulOp.getLhs().getDefiningOp(); + Value k; + Operation *rsqrt = + rewriteSqrtIfPossible(rewriter, mulOp.getLhs().getDefiningOp()); + if (rsqrt != nullptr) { k = mulOp.getRhs(); + } else { + k = mulOp.getLhs(); + rsqrt = + rewriteSqrtIfPossible(rewriter, mulOp.getRhs().getDefiningOp()); } - if (sqrtOp) { - // y/k - auto newDiv = rewriter.create( - op.getLoc(), op->getResultTypes(), op.getLhs(), k); - // rsqrt(x) - auto newRsqrt = rewriter.create( - op->getLoc(), sqrtOp->getResultTypes(), sqrtOp->getOperand(0)); - // y/k*rsqrt(x) - rewriter.replaceOpWithNewOp(op, op.getType(), newDiv, - newRsqrt); - return success(); + + // No 1/sqrt -> rsqrt rewrite, bailout + if (rsqrt == nullptr) { + return failure(); } + + auto newDiv = rewriter.create(op.getLoc(), op->getResultTypes(), + op.getLhs(), k); + rewriter.replaceOpWithNewOp(op, op.getType(), newDiv, + rsqrt->getResult(0)); + + return success(); } } return failure(); diff --git a/libspu/compiler/tests/optimizations/optimize_sqrt_to_rsqrt.mlir b/libspu/compiler/tests/optimizations/optimize_sqrt_to_rsqrt.mlir index 3488d2e3..8ea3ec5b 100644 --- a/libspu/compiler/tests/optimizations/optimize_sqrt_to_rsqrt.mlir +++ b/libspu/compiler/tests/optimizations/optimize_sqrt_to_rsqrt.mlir @@ -60,3 +60,16 @@ func.func @main(%arg0: tensor<3x4x!pphlo.secret>) -> tensor<3x4x!pphlo.secr %7 = pphlo.divide %0, %6 : tensor<3x4x!pphlo.secret> return %7 : tensor<3x4x!pphlo.secret> } + +// ----- + +func.func @main(%arg0: tensor<3x4x!pphlo.secret>, %arg1: tensor<5x6x!pphlo.secret>) -> tensor<3x3x4x!pphlo.secret> { + // CHECK-NOT: pphlo.sqrt + // CHECK: pphlo.rsqrt + %0 = pphlo.broadcast %arg0, dims = [1, 2] : (tensor<3x4x!pphlo.secret>) -> tensor<3x3x4x!pphlo.secret> + %1 = pphlo.sqrt %arg0 : tensor<3x4x!pphlo.secret> + %2 = pphlo.broadcast %1, dims = [0, 2] : (tensor<3x4x!pphlo.secret>) -> tensor<3x3x4x!pphlo.secret> + %3 = pphlo.multiply %0, %2 : tensor<3x3x4x!pphlo.secret> + %4 = pphlo.divide %0, %3 : tensor<3x3x4x!pphlo.secret> + return %4 : tensor<3x3x4x!pphlo.secret> + }