Skip to content

Commit

Permalink
Support more sqrt->rsqrt rewrite (#566)
Browse files Browse the repository at this point in the history
# Pull Request

## What problem does this PR solve?

Issue Number: Fixed #564 

## Possible side effects?

- Performance: better

- Backward compatibility: n/a
  • Loading branch information
anakinxc authored Feb 20, 2024
1 parent e9d8933 commit d9c2224
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 18 deletions.
37 changes: 19 additions & 18 deletions libspu/compiler/passes/rewrite_div_sqrt_patterns.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <iostream>
#include <limits>

#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
Expand Down Expand Up @@ -70,24 +67,28 @@ struct DivRewriter : public OpRewritePattern<DivOp> {
// Pattern 2:
// y/(k*sqrt(x)) -> y/k*rsqrt(x)
if (auto mulOp = denominator.getDefiningOp<MulOp>()) {
auto sqrtOp = mulOp.getRhs().getDefiningOp<SqrtOp>();
auto k = mulOp.getLhs();
if (sqrtOp == nullptr) {
sqrtOp = mulOp.getLhs().getDefiningOp<SqrtOp>();
Value k;
Operation *rsqrt =
rewriteSqrtIfPossible(rewriter, mulOp.getLhs().getDefiningOp());
if (rsqrt != nullptr) {
k = mulOp.getRhs();
} else {
k = mulOp.getLhs();
rsqrt =
rewriteSqrtIfPossible(rewriter, mulOp.getRhs().getDefiningOp());
}
if (sqrtOp) {
// y/k
auto newDiv = rewriter.create<DivOp>(
op.getLoc(), op->getResultTypes(), op.getLhs(), k);
// rsqrt(x)
auto newRsqrt = rewriter.create<RsqrtOp>(
op->getLoc(), sqrtOp->getResultTypes(), sqrtOp->getOperand(0));
// y/k*rsqrt(x)
rewriter.replaceOpWithNewOp<MulOp>(op, op.getType(), newDiv,
newRsqrt);
return success();

// Neither operand of the multiply could be rewritten to rsqrt, so bail out
if (rsqrt == nullptr) {
return failure();
}

auto newDiv = rewriter.create<DivOp>(op.getLoc(), op->getResultTypes(),
op.getLhs(), k);
rewriter.replaceOpWithNewOp<MulOp>(op, op.getType(), newDiv,
rsqrt->getResult(0));

return success();
}
}
return failure();
Expand Down
13 changes: 13 additions & 0 deletions libspu/compiler/tests/optimizations/optimize_sqrt_to_rsqrt.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,16 @@ func.func @main(%arg0: tensor<3x4x!pphlo.secret<i32>>) -> tensor<3x4x!pphlo.secr
%7 = pphlo.divide %0, %6 : tensor<3x4x!pphlo.secret<f32>>
return %7 : tensor<3x4x!pphlo.secret<f32>>
}

// -----

// Test case: a divide whose denominator is a multiply of a broadcast value and
// a broadcast sqrt, i.e. %0 / (%0 * broadcast(sqrt(%arg0))).
// The FileCheck directives inside the body look for pphlo.rsqrt in the output
// and reject pphlo.sqrt.
// NOTE(review): %arg1 is never referenced in the body; confirm it is intended.
func.func @main(%arg0: tensor<3x4x!pphlo.secret<f32>>, %arg1: tensor<5x6x!pphlo.secret<f32>>) -> tensor<3x3x4x!pphlo.secret<f32>> {
// CHECK-NOT: pphlo.sqrt
// CHECK: pphlo.rsqrt
// %0 serves as both the numerator y and the multiplier k: %arg0 broadcast to 3x3x4.
%0 = pphlo.broadcast %arg0, dims = [1, 2] : (tensor<3x4x!pphlo.secret<f32>>) -> tensor<3x3x4x!pphlo.secret<f32>>
// sqrt is applied on the un-broadcast 3x4 input, then broadcast to 3x3x4, so the
// multiply operand is defined by a broadcast rather than directly by the sqrt.
// NOTE(review): presumably this is the case the rewrite must look through; the
// helper that does so is not visible here, so verify against the pass.
%1 = pphlo.sqrt %arg0 : tensor<3x4x!pphlo.secret<f32>>
%2 = pphlo.broadcast %1, dims = [0, 2] : (tensor<3x4x!pphlo.secret<f32>>) -> tensor<3x3x4x!pphlo.secret<f32>>
// Denominator k * sqrt(x), with the (broadcast) sqrt as the right-hand operand.
%3 = pphlo.multiply %0, %2 : tensor<3x3x4x!pphlo.secret<f32>>
// y / (k * sqrt(x)), the shape named "Pattern 2" in the pass.
%4 = pphlo.divide %0, %3 : tensor<3x3x4x!pphlo.secret<f32>>
return %4 : tensor<3x3x4x!pphlo.secret<f32>>
}

0 comments on commit d9c2224

Please sign in to comment.