Skip to content

Commit a267ca7

Browse files
authored
Improve performance of x**y when y is a huge value (#438)
When y.exponent is several thousand or more, x**y was slow because exponentiation by squaring requires several thousands of multiplications. Use exp and log in such case. Needed to calaculate (1+1/n).power(n, prec)
1 parent cb2458b commit a267ca7

File tree

2 files changed

+22
-14
lines changed

2 files changed

+22
-14
lines changed

lib/bigdecimal.rb

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -165,22 +165,25 @@ def power(y, prec = nil)
165165
return BigDecimal(1).div(inv, prec)
166166
end
167167

168-
int_part = y.fix.to_i
169168
prec2 = prec + BigDecimal.double_fig
170-
pow_prec = prec2 + (int_part > 0 ? y.exponent : 0)
171-
ans = BigDecimal(1)
172-
n = 1
173-
xn = x
174-
while true
175-
ans = ans.mult(xn, pow_prec) if int_part.allbits?(n)
176-
n <<= 1
177-
break if n > int_part
178-
xn = xn.mult(xn, pow_prec)
179-
end
180-
unless frac_part.zero?
181-
ans = ans.mult(BigMath.exp(BigMath.log(x, prec2).mult(frac_part, prec2), prec2), prec2)
169+
170+
if frac_part.zero? && y.exponent < Math.log(prec) * 5 + 20
171+
# Use exponentiation by squaring if y is an integer and not too large
172+
pow_prec = prec2 + y.exponent
173+
n = 1
174+
xn = x
175+
ans = BigDecimal(1)
176+
int_part = y.fix.to_i
177+
while true
178+
ans = ans.mult(xn, pow_prec) if int_part.allbits?(n)
179+
n <<= 1
180+
break if n > int_part
181+
xn = xn.mult(xn, pow_prec)
182+
end
183+
ans.mult(1, prec)
184+
else
185+
BigMath.exp(BigMath.log(x, prec2).mult(y, prec2), prec)
182186
end
183-
ans.mult(1, prec)
184187
end
185188

186189
# Returns the square root of the value.

test/bigdecimal/test_bigdecimal.rb

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1991,6 +1991,11 @@ def test_power_with_rational
19911991
assert_in_epsilon(z2, x2 ** y, 1e-99)
19921992
end
19931993

1994+
def test_power_with_huge_value
1995+
n = BigDecimal('7e+10000')
1996+
assert_equal(BigMath.exp(1, 100), (1 + BigDecimal(1).div(n, 120)).power(n, 100))
1997+
end
1998+
19941999
def test_power_precision
19952000
x = BigDecimal("1.41421356237309504880168872420969807856967187537695")
19962001
y = BigDecimal("3.14159265358979323846264338327950288419716939937511")

0 commit comments

Comments
 (0)