Skip to content

Commit 62a373c

Browse files
[MLIR] Add sincos fusion pass
We see performance improvements from using sincos to reuse calculations in hot loops that compute sin() and cos() on the same operand. Add a pass to identify sin() and cos() calls in the same block with the same operand and fast-math flags, and fuse them into a sincos op. Follow-up to: * llvm#160561 * llvm#160772
1 parent 82efd72 commit 62a373c

File tree

4 files changed

+142
-0
lines changed

4 files changed

+142
-0
lines changed

mlir/include/mlir/Dialect/Math/Transforms/Passes.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,4 +64,12 @@ def MathExpandOpsPass : Pass<"math-expand-ops"> {
6464
];
6565
}
6666

67+
def MathSincosFusionPass : Pass<"math-sincos-fusion"> {
68+
let summary = "Fuse sin and cos operations.";
69+
let description = [{
70+
Fuse sin and cos operations into a sincos operation.
71+
}];
72+
let dependentDialects = ["math::MathDialect"];
73+
}
74+
6775
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES

mlir/lib/Dialect/Math/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRMathTransforms
33
ExpandOps.cpp
44
ExtendToSupportedTypes.cpp
55
PolynomialApproximation.cpp
6+
SincosFusion.cpp
67
UpliftToFMA.cpp
78

89
ADDITIONAL_HEADER_DIRS
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
//===- SincosFusion.cpp - Fuse sin/cos into sincos -----------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Math/IR/Math.h"
10+
#include "mlir/Dialect/Math/Transforms/Passes.h"
11+
#include "mlir/IR/PatternMatch.h"
12+
#include "mlir/Pass/Pass.h"
13+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
14+
15+
using namespace mlir;
16+
using namespace mlir::math;
17+
18+
namespace {
19+
20+
/// Fuse a math.sin and math.cos in the same block that use the same operand and
21+
/// have identical fastmath flags into a single math.sincos.
22+
struct SincosFusionPattern : OpRewritePattern<math::SinOp> {
23+
using OpRewritePattern<math::SinOp>::OpRewritePattern;
24+
25+
LogicalResult matchAndRewrite(math::SinOp sinOp,
26+
PatternRewriter &rewriter) const override {
27+
Value operand = sinOp.getOperand();
28+
auto sinFastMathFlags = sinOp.getFastmath();
29+
30+
math::CosOp cosOp = nullptr;
31+
sinOp->getBlock()->walk([&](math::CosOp op) {
32+
if (op.getOperand() == operand && op.getFastmath() == sinFastMathFlags)
33+
cosOp = op;
34+
});
35+
36+
if (!cosOp)
37+
return failure();
38+
39+
Type elemType = sinOp.getType();
40+
auto sincos = rewriter.create<math::SincosOp>(
41+
sinOp.getLoc(), TypeRange{elemType, elemType}, operand,
42+
sinOp.getFastmathAttr());
43+
44+
rewriter.replaceOp(sinOp, sincos.getSin());
45+
rewriter.replaceOp(cosOp, sincos.getCos());
46+
return success();
47+
}
48+
};
49+
50+
} // namespace
51+
52+
namespace mlir::math {
53+
#define GEN_PASS_DEF_MATHSINCOSFUSIONPASS
54+
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
55+
} // namespace mlir::math
56+
57+
namespace {
58+
59+
struct MathSincosFusionPass final
60+
: math::impl::MathSincosFusionPassBase<MathSincosFusionPass> {
61+
using MathSincosFusionPassBase::MathSincosFusionPassBase;
62+
63+
void runOnOperation() override {
64+
RewritePatternSet patterns(&getContext());
65+
patterns.add<SincosFusionPattern>(&getContext());
66+
67+
GreedyRewriteConfig config;
68+
if (failed(
69+
applyPatternsGreedily(getOperation(), std::move(patterns), config)))
70+
return signalPassFailure();
71+
}
72+
};
73+
74+
} // namespace
75+
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// RUN: mlir-opt -math-sincos-fusion %s | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @sincos_fusion(
4+
// CHECK-SAME: %[[ARG0:.*]]: f32) -> (f32, f32) {
5+
// CHECK: %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] : f32
6+
// CHECK: return %[[VAL_0]], %[[VAL_1]] : f32, f32
7+
// CHECK: }
8+
func.func @sincos_fusion(%arg0 : f32) -> (f32, f32) {
9+
%0 = math.sin %arg0 : f32
10+
%1 = math.cos %arg0 : f32
11+
func.return %0, %1 : f32, f32
12+
}
13+
14+
// CHECK-LABEL: func.func @sincos_fusion_no_match_fmf(
15+
// CHECK-SAME: %[[ARG0:.*]]: f32) -> (f32, f32) {
16+
// CHECK: %[[VAL_0:.*]] = math.sin %[[ARG0]] fastmath<contract> : f32
17+
// CHECK: %[[VAL_1:.*]] = math.cos %[[ARG0]] : f32
18+
// CHECK: return %[[VAL_0]], %[[VAL_1]] : f32, f32
19+
// CHECK: }
20+
func.func @sincos_fusion_no_match_fmf(%arg0 : f32) -> (f32, f32) {
21+
%0 = math.sin %arg0 fastmath<contract> : f32
22+
%1 = math.cos %arg0 : f32
23+
func.return %0, %1 : f32, f32
24+
}
25+
26+
// CHECK-LABEL: func.func @sincos_no_fusion_different_block(
27+
// CHECK-SAME: %[[ARG0:.*]]: f32,
28+
// CHECK-SAME: %[[ARG1:.*]]: i1) -> f32 {
29+
// CHECK: %[[VAL_0:.*]] = scf.if %[[ARG1]] -> (f32) {
30+
// CHECK: %[[VAL_1:.*]] = math.sin %[[ARG0]] : f32
31+
// CHECK: scf.yield %[[VAL_1]] : f32
32+
// CHECK: } else {
33+
// CHECK: %[[VAL_2:.*]] = math.cos %[[ARG0]] : f32
34+
// CHECK: scf.yield %[[VAL_2]] : f32
35+
// CHECK: }
36+
// CHECK: return %[[VAL_0]] : f32
37+
// CHECK: }
38+
func.func @sincos_no_fusion_different_block(%arg0 : f32, %flag : i1) -> f32 {
39+
%0 = scf.if %flag -> f32 {
40+
%s = math.sin %arg0 : f32
41+
scf.yield %s : f32
42+
} else {
43+
%c = math.cos %arg0 : f32
44+
scf.yield %c : f32
45+
}
46+
func.return %0 : f32
47+
}
48+
49+
// CHECK-LABEL: func.func @sincos_fusion_preserve_fastmath(
50+
// CHECK-SAME: %[[ARG0:.*]]: f32) -> (f32, f32) {
51+
// CHECK: %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] fastmath<contract> : f32
52+
// CHECK: return %[[VAL_0]], %[[VAL_1]] : f32, f32
53+
// CHECK: }
54+
func.func @sincos_fusion_preserve_fastmath(%arg0 : f32) -> (f32, f32) {
55+
%0 = math.sin %arg0 fastmath<contract> : f32
56+
%1 = math.cos %arg0 fastmath<contract> : f32
57+
func.return %0, %1 : f32, f32
58+
}

0 commit comments

Comments
 (0)