Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster Rationals by avoiding unnecessary divgcd #35492

Merged
merged 10 commits into from
May 4, 2020
127 changes: 80 additions & 47 deletions base/rational.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,34 @@ struct Rational{T<:Integer} <: Real
num::T
den::T

function Rational{T}(num::Integer, den::Integer) where T<:Integer
num == den == zero(T) && __throw_rational_argerror_zero(T)
num2, den2 = divgcd(num, den)
if T<:Signed && signbit(den2)
den2 = -den2
signbit(den2) && __throw_rational_argerror_typemin(T)
num2 = -num2
end
return new(num2, den2)
# Unexported inner constructor of Rational that bypasses all checks
global unsafe_rational(::Type{T}, num, den) where {T} = new{T}(num, den)
Liozou marked this conversation as resolved.
Show resolved Hide resolved
end

unsafe_rational(num::T, den::T) where {T<:Integer} = unsafe_rational(T, num, den)
unsafe_rational(num::Integer, den::Integer) = unsafe_rational(promote(num, den)...)

@noinline __throw_rational_argerror_typemin(T) = throw(ArgumentError("invalid rational: denominator can't be typemin($T)"))
function checked_den(num::T, den::T) where T<:Integer
if signbit(den)
den = -den
signbit(den) && __throw_rational_argerror_typemin(T)
num = -num
end
return unsafe_rational(T, num, den)
end
checked_den(num::Integer, den::Integer) = checked_den(promote(num, den)...)

@noinline __throw_rational_argerror_zero(T) = throw(ArgumentError("invalid rational: zero($T)//zero($T)"))
@noinline __throw_rational_argerror_typemin(T) = throw(ArgumentError("invalid rational: denominator can't be typemin($T)"))
function Rational{T}(num::Integer, den::Integer) where T<:Integer
iszero(den) && iszero(num) && __throw_rational_argerror_zero(T)
num, den = divgcd(num, den)
return checked_den(T(num), T(den))
end

Rational(n::T, d::T) where {T<:Integer} = Rational{T}(n,d)
Rational(n::Integer, d::Integer) = Rational(promote(n,d)...)
Rational(n::Integer) = Rational(n,one(n))
Rational(n::T, d::T) where {T<:Integer} = Rational{T}(n, d)
Rational(n::Integer, d::Integer) = Rational(promote(n, d)...)
Rational(n::Integer) = unsafe_rational(n, one(n))

function divgcd(x::Integer,y::Integer)
g = gcd(x,y)
Expand All @@ -50,20 +61,20 @@ julia> (3 // 5) // (2 // 1)
//(n::Integer, d::Integer) = Rational(n,d)

function //(x::Rational, y::Integer)
xn,yn = divgcd(x.num,y)
xn//checked_mul(x.den,yn)
xn, yn = divgcd(x.num,y)
checked_den(xn, checked_mul(x.den, yn))
end
function //(x::Integer, y::Rational)
xn,yn = divgcd(x,y.num)
checked_mul(xn,y.den)//yn
xn, yn = divgcd(x,y.num)
checked_den(checked_mul(xn, y.den), yn)
end
function //(x::Rational, y::Rational)
xn,yn = divgcd(x.num,y.num)
xd,yd = divgcd(x.den,y.den)
checked_mul(xn,yd)//checked_mul(xd,yn)
checked_den(checked_mul(xn, yd), checked_mul(xd, yn))
end

//(x::Complex, y::Real) = complex(real(x)//y,imag(x)//y)
//(x::Complex, y::Real) = complex(real(x)//y, imag(x)//y)
//(x::Number, y::Complex) = x*conj(y)//abs2(y)


Expand All @@ -84,8 +95,12 @@ function write(s::IO, z::Rational)
write(s,numerator(z),denominator(z))
end

Rational{T}(x::Rational) where {T<:Integer} = Rational{T}(convert(T,x.num), convert(T,x.den))
Rational{T}(x::Integer) where {T<:Integer} = Rational{T}(convert(T,x), convert(T,1))
function Rational{T}(x::Rational) where T<:Integer
unsafe_rational(T, convert(T, x.num), convert(T, x.den))
end
function Rational{T}(x::Integer) where T<:Integer
unsafe_rational(T, convert(T, x), one(T))
end

Rational(x::Rational) = x

Expand All @@ -108,7 +123,7 @@ end
Rational(x::Float64) = Rational{Int64}(x)
Rational(x::Float32) = Rational{Int}(x)

big(q::Rational) = big(numerator(q))//big(denominator(q))
big(q::Rational) = unsafe_rational(big(numerator(q)), big(denominator(q)))

big(z::Complex{<:Rational{<:Integer}}) = Complex{Rational{BigInt}}(z)

Expand All @@ -118,6 +133,8 @@ promote_rule(::Type{Rational{T}}, ::Type{S}) where {T<:Integer,S<:AbstractFloat}

widen(::Type{Rational{T}}) where {T} = Rational{widen(T)}

@noinline __throw_negate_unsigned() = throw(OverflowError("cannot negate unsigned number"))

"""
rationalize([T<:Integer=Int,] x; tol::Real=eps(x))

Expand All @@ -140,8 +157,9 @@ function rationalize(::Type{T}, x::AbstractFloat, tol::Real) where T<:Integer
if tol < 0
throw(ArgumentError("negative tolerance $tol"))
end
T<:Unsigned && x < 0 && __throw_negate_unsigned()
isnan(x) && return T(x)//one(T)
isinf(x) && return (x < 0 ? -one(T) : one(T))//zero(T)
isinf(x) && return unsafe_rational(x < 0 ? -one(T) : one(T), zero(T))

p, q = (x < 0 ? -one(T) : one(T)), zero(T)
pp, qq = zero(T), one(T)
Expand Down Expand Up @@ -234,59 +252,74 @@ denominator(x::Rational) = x.den

sign(x::Rational) = oftype(x, sign(x.num))
signbit(x::Rational) = signbit(x.num)
copysign(x::Rational, y::Real) = copysign(x.num,y) // x.den
copysign(x::Rational, y::Rational) = copysign(x.num,y.num) // x.den
copysign(x::Rational, y::Real) = unsafe_rational(copysign(x.num, y), x.den)
copysign(x::Rational, y::Rational) = unsafe_rational(copysign(x.num, y.num), x.den)

abs(x::Rational) = Rational(abs(x.num), x.den)

typemin(::Type{Rational{T}}) where {T<:Integer} = -one(T)//zero(T)
typemax(::Type{Rational{T}}) where {T<:Integer} = one(T)//zero(T)
typemin(::Type{Rational{T}}) where {T<:Signed} = unsafe_rational(T, -one(T), zero(T))
typemin(::Type{Rational{T}}) where {T<:Integer} = unsafe_rational(T, zero(T), one(T))
typemax(::Type{Rational{T}}) where {T<:Integer} = unsafe_rational(T, one(T), zero(T))

isinteger(x::Rational) = x.den == 1

+(x::Rational) = (+x.num) // x.den
-(x::Rational) = (-x.num) // x.den
+(x::Rational) = unsafe_rational(+x.num, x.den)
-(x::Rational) = unsafe_rational(-x.num, x.den)

function -(x::Rational{T}) where T<:BitSigned
x.num == typemin(T) && throw(OverflowError("rational numerator is typemin(T)"))
(-x.num) // x.den
x.num == typemin(T) && __throw_rational_numerator_typemin(T)
unsafe_rational(-x.num, x.den)
end
@noinline __throw_rational_numerator_typemin(T) = throw(OverflowError("rational numerator is typemin($T)"))

function -(x::Rational{T}) where T<:Unsigned
x.num != zero(T) && throw(OverflowError("cannot negate unsigned number"))
x.num != zero(T) && __throw_negate_unsigned()
x
end

for (op,chop) in ((:+,:checked_add), (:-,:checked_sub),
(:rem,:rem), (:mod,:mod))
for (op,chop) in ((:+,:checked_add), (:-,:checked_sub), (:rem,:rem), (:mod,:mod))
@eval begin
function ($op)(x::Rational, y::Rational)
xd, yd = divgcd(x.den, y.den)
Rational(($chop)(checked_mul(x.num,yd), checked_mul(y.num,xd)), checked_mul(x.den,yd))
end

function ($op)(x::Rational, y::Integer)
Rational(($chop)(x.num, checked_mul(x.den, y)), x.den)
unsafe_rational(($chop)(x.num, checked_mul(x.den, y)), x.den)
end

end
end
for (op,chop) in ((:+,:checked_add), (:-,:checked_sub))
@eval begin
function ($op)(y::Integer, x::Rational)
unsafe_rational(($chop)(checked_mul(x.den, y), x.num), x.den)
end
end
end
for (op,chop) in ((:rem,:rem), (:mod,:mod))
@eval begin
function ($op)(y::Integer, x::Rational)
Rational(($chop)(checked_mul(x.den, y), x.num), x.den)
end
end
end

function *(x::Rational, y::Rational)
xn,yd = divgcd(x.num,y.den)
xd,yn = divgcd(x.den,y.num)
checked_mul(xn,yn) // checked_mul(xd,yd)
xn, yd = divgcd(x.num, y.den)
xd, yn = divgcd(x.den, y.num)
unsafe_rational(checked_mul(xn, yn), checked_mul(xd, yd))
end
function *(x::Rational, y::Integer)
xd, yn = divgcd(x.den, y)
checked_mul(x.num, yn) // xd
unsafe_rational(checked_mul(x.num, yn), xd)
end
function *(y::Integer, x::Rational)
yn, xd = divgcd(y, x.den)
unsafe_rational(checked_mul(yn, x.num), xd)
end
*(x::Integer, y::Rational) = *(y, x)
/(x::Rational, y::Rational) = x//y
/(x::Rational, y::Complex{<:Union{Integer,Rational}}) = x//y
inv(x::Rational) = Rational(x.den, x.num)
/(x::Rational, y::Union{Rational, Integer, Complex{<:Union{Integer,Rational}}}) = x//y
/(x::Union{Integer, Complex{<:Union{Integer,Rational}}}, y::Rational) = x//y
inv(x::Rational{T}) where {T} = checked_den(x.den, x.num)

fma(x::Rational, y::Rational, z::Rational) = x*y+z

Expand Down Expand Up @@ -403,7 +436,7 @@ round(x::Rational, r::RoundingMode=RoundNearest) = round(typeof(x), x, r)

function round(::Type{T}, x::Rational{Tr}, r::RoundingMode=RoundNearest) where {T,Tr}
if iszero(denominator(x)) && !(T <: Integer)
return convert(T, copysign(one(Tr)//zero(Tr), numerator(x)))
return convert(T, copysign(unsafe_rational(one(Tr), zero(Tr)), numerator(x)))
end
convert(T, div(numerator(x), denominator(x), r))
end
Expand Down Expand Up @@ -437,8 +470,8 @@ end

float(::Type{Rational{T}}) where {T<:Integer} = float(T)

gcd(x::Rational, y::Rational) = gcd(x.num, y.num) // lcm(x.den, y.den)
lcm(x::Rational, y::Rational) = lcm(x.num, y.num) // gcd(x.den, y.den)
gcd(x::Rational, y::Rational) = unsafe_rational(gcd(x.num, y.num), lcm(x.den, y.den))
lcm(x::Rational, y::Rational) = unsafe_rational(lcm(x.num, y.num), gcd(x.den, y.den))
function gcdx(x::Rational, y::Rational)
c = gcd(x, y)
if iszero(c.num)
Expand Down
14 changes: 14 additions & 0 deletions test/rational.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ using Test
@test 5//0 == 1//0
@test -1//0 == -1//0
@test -7//0 == -1//0
@test (-1//2) // (-2//5) == 5//4

@test_throws OverflowError -(0x01//0x0f)
@test_throws OverflowError -(typemin(Int)//1)
Expand All @@ -26,9 +27,13 @@ using Test
@test (typemax(Int)//1) / (typemax(Int)//1) == 1
@test (1//typemax(Int)) / (1//typemax(Int)) == 1
@test_throws OverflowError (1//2)^63
@test inv((1+typemin(Int))//typemax(Int)) == -1
@test_throws ArgumentError inv(typemin(Int)//typemax(Int))
@test_throws ArgumentError Rational(0x1, typemin(Int32))

@test @inferred(rationalize(Int, 3.0, 0.0)) === 3//1
@test @inferred(rationalize(Int, 3.0, 0)) === 3//1
@test_throws OverflowError rationalize(UInt, -2.0)
@test_throws ArgumentError rationalize(Int, big(3.0), -1.)
# issue 26823
@test_throws InexactError rationalize(Int, NaN)
Expand All @@ -38,6 +43,11 @@ using Test
@test -2 // typemin(Int) == -1 // (typemin(Int) >> 1)
@test 2 // typemin(Int) == 1 // (typemin(Int) >> 1)

@test_throws InexactError Rational(UInt(1), typemin(Int32))
@test iszero(Rational{Int}(UInt(0), 1))
@test Rational{BigInt}(UInt(1), Int(-1)) == -1
@test_broken Rational{Int64}(UInt(1), typemin(Int32)) == Int64(1) // Int64(typemin(Int32))

for a = -5:5, b = -5:5
if a == b == 0; continue; end
if ispow2(b)
Expand Down Expand Up @@ -120,6 +130,8 @@ end
@test widen(Rational{T}) == Rational{widen(T)}
end

@test iszero(typemin(Rational{UInt}))

@test Rational(Float32(rand_int)) == Rational(rand_int)

@test Rational(Rational(rand_int)) == Rational(rand_int)
Expand Down Expand Up @@ -548,6 +560,8 @@ end
end
@test 1//2 * 3 == 3//2
@test -3 * (1//2) == -3//2
@test (6//5) // -3 == -2//5
@test -4 // (-6//5) == 10//3

@test_throws OverflowError UInt(1)//2 - 1
@test_throws OverflowError 1 - UInt(5)//2
Expand Down