@@ -91,40 +91,34 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
91
91
}
92
92
93
93
// / Expands tanh op into
94
- // / 1-exp^{-2x} / 1+exp^{-2x}
95
- // / To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
96
- // / We compute a "signs" value which is -1 if input is negative and +1 if input
97
- // / is positive. Then multiply the input by this value, guaranteeing that the
98
- // / result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0,
99
- // / 1]. Expand the computation on the input `x * sign(x)`, then multiply the
100
- // / result by `sign(x)` to retain sign of the real result.
94
+ // / 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
95
+ // / 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0
101
96
static LogicalResult convertTanhOp (math::TanhOp op, PatternRewriter &rewriter) {
102
97
auto floatType = op.getOperand ().getType ();
103
98
Location loc = op.getLoc ();
104
- Value zero = createFloatConst (loc, floatType, 0.0 , rewriter);
105
99
Value one = createFloatConst (loc, floatType, 1.0 , rewriter);
106
- Value negTwo = createFloatConst (loc, floatType, -2.0 , rewriter);
107
-
108
- // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
109
- Value sign = rewriter.create <arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
110
- op.getOperand (), zero);
111
- sign = rewriter.create <arith::SIToFPOp>(loc, floatType, sign);
112
- sign = rewriter.create <arith::MulFOp>(loc, sign, negTwo);
113
- sign = rewriter.create <arith::AddFOp>(loc, sign, one);
100
+ Value two = createFloatConst (loc, floatType, 2.0 , rewriter);
101
+ Value doubledX = rewriter.create <arith::MulFOp>(loc, op.getOperand (), two);
114
102
115
- // Normalize input to positive value: y = sign(x) * x
116
- Value positiveX = rewriter.create <arith::MulFOp>(loc, sign, op.getOperand ());
117
-
118
- // Decompose on normalized input
119
- Value negDoubledX = rewriter.create <arith::MulFOp>(loc, negTwo, positiveX);
103
+ // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
104
+ Value negDoubledX = rewriter.create <arith::NegFOp>(loc, doubledX);
120
105
Value exp2x = rewriter.create <math::ExpOp>(loc, negDoubledX);
121
106
Value dividend = rewriter.create <arith::SubFOp>(loc, one, exp2x);
122
107
Value divisor = rewriter.create <arith::AddFOp>(loc, one, exp2x);
123
108
Value positiveRes = rewriter.create <arith::DivFOp>(loc, dividend, divisor);
124
109
125
- // Multiply result by sign(x) to retain signs from negative inputs
126
- rewriter.replaceOpWithNewOp <arith::MulFOp>(op, sign, positiveRes);
110
+ // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
111
+ exp2x = rewriter.create <math::ExpOp>(loc, doubledX);
112
+ dividend = rewriter.create <arith::SubFOp>(loc, exp2x, one);
113
+ divisor = rewriter.create <arith::AddFOp>(loc, exp2x, one);
114
+ Value negativeRes = rewriter.create <arith::DivFOp>(loc, dividend, divisor);
127
115
116
+ // tanh(x) = x >= 0 ? positiveRes : negativeRes
117
+ Value zero = createFloatConst (loc, floatType, 0.0 , rewriter);
118
+ Value cmpRes = rewriter.create <arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
119
+ op.getOperand (), zero);
120
+ rewriter.replaceOpWithNewOp <arith::SelectOp>(op, cmpRes, positiveRes,
121
+ negativeRes);
128
122
return success ();
129
123
}
130
124
0 commit comments