Skip to content

Commit

Permalink
pythongh-118610: Centralize power caching in _pylong.py (python#118611
Browse files Browse the repository at this point in the history
)

A new `compute_powers()` function computes all and only the powers of the base the various base-conversion functions need, as efficiently as reasonably possible (turns out that invoking `**`is needed at most once). This typically gives a few % speedup, but the primary point is to simplify the base-conversion functions, which no longer need their own, ad hoc, and less efficient power-caching schemes.

Co-authored-by: Serhiy Storchaka <storchaka@gmail.com>
  • Loading branch information
tim-one and serhiy-storchaka authored May 8, 2024
1 parent 2a85bed commit 2f0a338
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 67 deletions.
168 changes: 101 additions & 67 deletions Lib/_pylong.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,86 @@
except ImportError:
_decimal = None

# A number of functions have this form, where `w` is a desired number of
# digits in base `base`:
#
# def inner(...w...):
# if w <= LIMIT:
# return something
# lo = w >> 1
# hi = w - lo
# something involving base**lo, inner(...lo...), j, and inner(...hi...)
# figure out largest w needed
# result = inner(w)
#
# They all had some on-the-fly scheme to cache `base**lo` results for reuse.
# Power is costly.
#
# This routine aims to compute all amd only the needed powers in advance, as
# efficiently as reasonably possible. This isn't trivial, and all the
# on-the-fly methods did needless work in many cases. The driving code above
# changes to:
#
# figure out largest w needed
# mycache = compute_powers(w, base, LIMIT)
# result = inner(w)
#
# and `mycache[lo]` replaces `base**lo` in the inner function.
#
# While this does give minor speedups (a few percent at best), the primary
# intent is to simplify the functions using this, by eliminating the need for
# them to craft their own ad-hoc caching schemes.
def compute_powers(w, base, more_than, show=False):
seen = set()
need = set()
ws = {w}
while ws:
w = ws.pop() # any element is fine to use next
if w in seen or w <= more_than:
continue
seen.add(w)
lo = w >> 1
# only _need_ lo here; some other path may, or may not, need hi
need.add(lo)
ws.add(lo)
if w & 1:
ws.add(lo + 1)

d = {}
if not need:
return d
it = iter(sorted(need))
first = next(it)
if show:
print("pow at", first)
d[first] = base ** first
for this in it:
if this - 1 in d:
if show:
print("* base at", this)
d[this] = d[this - 1] * base # cheap
else:
lo = this >> 1
hi = this - lo
assert lo in d
if show:
print("square at", this)
# Multiplying a bigint by itself (same object!) is about twice
# as fast in CPython.
sq = d[lo] * d[lo]
if hi != lo:
assert hi == lo + 1
if show:
print(" and * base")
sq *= base
d[this] = sq
return d

_unbounded_dec_context = decimal.getcontext().copy()
_unbounded_dec_context.prec = decimal.MAX_PREC
_unbounded_dec_context.Emax = decimal.MAX_EMAX
_unbounded_dec_context.Emin = decimal.MIN_EMIN
_unbounded_dec_context.traps[decimal.Inexact] = 1 # sanity check

def int_to_decimal(n):
"""Asymptotically fast conversion of an 'int' to Decimal."""
Expand All @@ -33,57 +113,32 @@ def int_to_decimal(n):
# "clever" recursive way. If we want a string representation, we
# apply str to _that_.

D = decimal.Decimal
D2 = D(2)

BITLIM = 128

mem = {}

def w2pow(w):
"""Return D(2)**w and store the result. Also possibly save some
intermediate results. In context, these are likely to be reused
across various levels of the conversion to Decimal."""
if (result := mem.get(w)) is None:
if w <= BITLIM:
result = D2**w
elif w - 1 in mem:
result = (t := mem[w - 1]) + t
else:
w2 = w >> 1
# If w happens to be odd, w-w2 is one larger then w2
# now. Recurse on the smaller first (w2), so that it's
# in the cache and the larger (w-w2) can be handled by
# the cheaper `w-1 in mem` branch instead.
result = w2pow(w2) * w2pow(w - w2)
mem[w] = result
return result
from decimal import Decimal as D
BITLIM = 200

# Don't bother caching the "lo" mask in this; the time to compute it is
# tiny compared to the multiply.
def inner(n, w):
if w <= BITLIM:
return D(n)
w2 = w >> 1
hi = n >> w2
lo = n - (hi << w2)
return inner(lo, w2) + inner(hi, w - w2) * w2pow(w2)

with decimal.localcontext() as ctx:
ctx.prec = decimal.MAX_PREC
ctx.Emax = decimal.MAX_EMAX
ctx.Emin = decimal.MIN_EMIN
ctx.traps[decimal.Inexact] = 1
lo = n & ((1 << w2) - 1)
return inner(lo, w2) + inner(hi, w - w2) * w2pow[w2]

with decimal.localcontext(_unbounded_dec_context):
nbits = n.bit_length()
w2pow = compute_powers(nbits, D(2), BITLIM)
if n < 0:
negate = True
n = -n
else:
negate = False
result = inner(n, n.bit_length())
result = inner(n, nbits)
if negate:
result = -result
return result


def int_to_decimal_string(n):
"""Asymptotically fast conversion of an 'int' to a decimal string."""
w = n.bit_length()
Expand All @@ -97,14 +152,13 @@ def int_to_decimal_string(n):
# available. This algorithm is asymptotically worse than the algorithm
# using the decimal module, but better than the quadratic time
# implementation in longobject.c.

DIGLIM = 1000
def inner(n, w):
if w <= 1000:
if w <= DIGLIM:
return str(n)
w2 = w >> 1
d = pow10_cache.get(w2)
if d is None:
d = pow10_cache[w2] = 5**w2 << w2 # 10**i = (5*2)**i = 5**i * 2**i
hi, lo = divmod(n, d)
hi, lo = divmod(n, pow10[w2])
return inner(hi, w - w2) + inner(lo, w2).zfill(w2)

# The estimation of the number of decimal digits.
Expand All @@ -115,7 +169,9 @@ def inner(n, w):
# only if the number has way more than 10**15 digits, that exceeds
# the 52-bit physical address limit in both Intel64 and AMD64.
w = int(w * 0.3010299956639812 + 1) # log10(2)
pow10_cache = {}
pow10 = compute_powers(w, 5, DIGLIM)
for k, v in pow10.items():
pow10[k] = v << k # 5**k << k == 5**k * 2**k == 10**k
if n < 0:
n = -n
sign = '-'
Expand All @@ -128,7 +184,6 @@ def inner(n, w):
s = s.lstrip('0')
return sign + s


def _str_to_int_inner(s):
"""Asymptotically fast conversion of a 'str' to an 'int'."""

Expand All @@ -144,35 +199,15 @@ def _str_to_int_inner(s):

DIGLIM = 2048

mem = {}

def w5pow(w):
"""Return 5**w and store the result.
Also possibly save some intermediate results. In context, these
are likely to be reused across various levels of the conversion
to 'int'.
"""
if (result := mem.get(w)) is None:
if w <= DIGLIM:
result = 5**w
elif w - 1 in mem:
result = mem[w - 1] * 5
else:
w2 = w >> 1
# If w happens to be odd, w-w2 is one larger then w2
# now. Recurse on the smaller first (w2), so that it's
# in the cache and the larger (w-w2) can be handled by
# the cheaper `w-1 in mem` branch instead.
result = w5pow(w2) * w5pow(w - w2)
mem[w] = result
return result

def inner(a, b):
if b - a <= DIGLIM:
return int(s[a:b])
mid = (a + b + 1) >> 1
return inner(mid, b) + ((inner(a, mid) * w5pow(b - mid)) << (b - mid))
return (inner(mid, b)
+ ((inner(a, mid) * w5pow[b - mid])
<< (b - mid)))

w5pow = compute_powers(len(s), 5, DIGLIM)
return inner(0, len(s))


Expand All @@ -186,7 +221,6 @@ def int_from_string(s):
s = s.rstrip().replace('_', '')
return _str_to_int_inner(s)


def str_to_int(s):
"""Asymptotically fast version of decimal string to 'int' conversion."""
# FIXME: this doesn't support the full syntax that int() supports.
Expand Down
12 changes: 12 additions & 0 deletions Lib/test/test_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,18 @@ def test_pylong_misbehavior_error_path_from_str(
with self.assertRaises(RuntimeError):
int(big_value)

def test_pylong_roundtrip(self):
from random import randrange, getrandbits
bits = 5000
while bits <= 1_000_000:
bits += randrange(-100, 101) # break bitlength patterns
hibit = 1 << (bits - 1)
n = hibit | getrandbits(bits - 1)
assert n.bit_length() == bits
sn = str(n)
self.assertFalse(sn.startswith('0'))
self.assertEqual(n, int(sn))
bits <<= 1

if __name__ == "__main__":
unittest.main()

0 comments on commit 2f0a338

Please sign in to comment.