Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fast multiplicative inverses #15357

Merged
merged 4 commits into from
Mar 21, 2016
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions base/multinverses.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
module MultiplicativeInverses

import Base: div, divrem, rem
using Base: LinearFast, LinearSlow, tail
export multiplicativeinverse

unsigned_type(::Int8) = UInt8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if these should just be methods of unsigned{T<:Signed}(::Type{T}) ?

unsigned_type(::Int16) = UInt16
unsigned_type(::Int32) = UInt32
unsigned_type(::Int64) = UInt64
unsigned_type(::Int128) = UInt128

abstract MultiplicativeInverse{T}

immutable SignedMultiplicativeInverse{T<:Signed} <: MultiplicativeInverse{T}
divisor::T
multiplier::T
addmul::Int8
shift::UInt8

function SignedMultiplicativeInverse(d::T)
d == 0 && error("cannot compute magic for d == $d")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ArgumentError ? or DivideError ?

ut = unsigned_type(d)
signedmin = reinterpret(ut, typemin(d))

ad::ut = abs(d)
t::ut = signedmin + signbit(d)
anc::ut = t - 1 - rem(t, ad) # absolute value of nc
p = sizeof(d)*8 - 1 # initialize p
q1::ut, r1::ut = divrem(signedmin, anc)
q2::ut, r2::ut = divrem(signedmin, ad)
while true
p += 1
q1 *= 2 # update q1 = 2p/abs(nc)
r1 *= 2 # update r1 = rem(2p/abs(nc))
if r1 >= anc # must be unsigned comparison
q1 += 1
r1 -= anc
end
q2 *= 2 # update q2 = 2p/abs(d)
r2 *= 2 # update r2 = rem(2p/abs(d))
if r2 >= ad # must be unsigned comparison
q2 += 1
r2 -= ad
end
delta::ut = ad - r2
(q1 < delta || (q1 == delta && r1 == 0)) || break
end

m = flipsign((q2 + 1) % T, d) # resulting magic number
s = p - sizeof(d)*8 # resulting shift
new(d, m, d > 0 && m < 0 ? Int8(1) : d < 0 && m > 0 ? Int8(-1) : Int8(0), UInt8(s))
end
end
SignedMultiplicativeInverse(x::Signed) = SignedMultiplicativeInverse{typeof(x)}(x)

immutable UnsignedMultiplicativeInverse{T<:Unsigned} <: MultiplicativeInverse{T}
divisor::T
multiplier::T
add::Bool
shift::UInt8

function UnsignedMultiplicativeInverse(d::T)
d == 0 && error("cannot compute magic for d == $d")
u2 = convert(T, 2)
add = false
signedmin::typeof(d) = one(d) << (sizeof(d)*8-1)
signedmax::typeof(d) = signedmin - 1
allones = (zero(d) - 1) % T

nc::typeof(d) = allones - rem(convert(T, allones - d), d)
p = 8*sizeof(d) - 1 # initialize p
q1::typeof(d), r1::typeof(d) = divrem(signedmin, nc)
q2::typeof(d), r2::typeof(d) = divrem(signedmax, d)
while true
p += 1
if r1 >= convert(T, nc - r1)
q1 = q1 + q1 + T(1) # update q1
r1 = r1 + r1 - nc # update r1
else
q1 = q1 + q1 # update q1
r1 = r1 + r1 # update r1
end
if convert(T, r2 + T(1)) >= convert(T, d - r2)
add |= q2 >= signedmax
q2 = q2 + q2 + 1 # update q2
r2 = r2 + r2 + T(1) - d # update r2
else
add |= q2 >= signedmin
q2 = q2 + q2 # update q2
r2 = r2 + r2 + T(1) # update r2
end
delta::typeof(d) = d - 1 - r2
(p < sizeof(d)*16 && (q1 < delta || (q1 == delta && r1 == 0))) || break
end
m = q2 + 1 # resulting magic number
s = p - sizeof(d)*8 - add # resulting shift
new(d, m % T, add, s % UInt8)
end
end
UnsignedMultiplicativeInverse(x::Unsigned) = UnsignedMultiplicativeInverse{typeof(x)}(x)

function div{T}(a::T, b::SignedMultiplicativeInverse{T})
x = ((widen(a)*b.multiplier) >>> sizeof(a)*8) % T
x += (a*b.addmul) % T
ifelse(abs(b.divisor) == 1, a*b.divisor, (signbit(x) + (x >> b.shift)) % T)
end
function div{T}(a::T, b::UnsignedMultiplicativeInverse{T})
x = ((widen(a)*b.multiplier) >>> sizeof(a)*8) % T
x = ifelse(b.add, convert(T, convert(T, (convert(T, a - x) >>> 1)) + x), x)
ifelse(b.divisor == 1, a, x >>> b.shift)
end

rem{T}(a::T, b::MultiplicativeInverse{T}) =
a - div(a, b)*b.divisor

function divrem{T}(a::T, b::MultiplicativeInverse{T})
d = div(a, b)
(d, a - d*b.divisor)
end

multiplicativeinverse(x::Signed) = SignedMultiplicativeInverse(x)
multiplicativeinverse(x::Unsigned) = UnsignedMultiplicativeInverse(x)

end
2 changes: 2 additions & 0 deletions base/sysimg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ importall .Rounding
include("float.jl")
include("complex.jl")
include("rational.jl")
include("multinverses.jl")
using .MultiplicativeInverses
include("abstractarraymath.jl")
include("arraymath.jl")

Expand Down
13 changes: 13 additions & 0 deletions test/numbers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2858,3 +2858,16 @@ for T in (Int8, Int16, Int32, Int64, Bool)
@test round(T, true//true) === one(T)
@test round(T, false//true) === zero(T)
end

# multiplicative inverses
function testmi(numrange, denrange)
for d in denrange
d == 0 && continue
fastd = Base.multiplicativeinverse(d)
for n in numrange
@test div(n,d) == div(n,fastd)
end
end
end
testmi(-1000:1000, -100:100)
@test_throws ErrorException Base.multiplicativeinverse(0)