Skip to content

Commit

Permalink
Refactor constructors using unsafe_rational
Browse files Browse the repository at this point in the history
  • Loading branch information
Liozou committed Apr 22, 2020
1 parent f659f84 commit 303aed8
Showing 1 changed file with 57 additions and 52 deletions.
109 changes: 57 additions & 52 deletions base/rational.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,22 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

"""
unsafe_rational(num::Integer, den::Integer; checksign=false)
unsafe_rational{T<:Integer}(num::Integer, den::Integer; checksign::Bool=false)
Create a `Rational` with numerator `num` and denominator `den`.
The `unsafe` prefix on this function indicates that no check is performed on
`num` and `den` to ensure that the resulting `Rational` is well-formed.
A `Rational` is well-formed if its numerator and denominator are coprime and if
the denominator is non-negative.
`checksign` optionally specifies whether to check the sign of `den`.
Ill-formed `Rational`s result in undefined behaviour.
"""
struct unsafe_rational{T<:Integer} <: Function end

"""
Rational{T<:Integer} <: Real
Expand All @@ -10,42 +27,30 @@ struct Rational{T<:Integer} <: Real
num::T
den::T

function Rational{T}(num::Integer, den::Integer) where T<:Integer
num == den == zero(T) && __throw_rational_argerror_zero(T)
num2, den2 = divgcd(num, den)
if T<:Signed && signbit(den2)
den2 = -den2
signbit(den2) && __throw_rational_argerror_typemin(T)
num2 = -num2
end
return new(num2, den2)
end

function Rational{T}(num::Integer, den::Integer, ::Val{false}) where T<:Integer
# Used when num and den are only known to be coprime and not both equal to 0
if T<:Signed && signbit(den)
function (::Type{unsafe_rational{T}})(num::Integer, den::Integer; checksign::Bool=false) where T<:Integer
if checksign && T<:Signed && signbit(den)
den = -den
signbit(den) && __throw_rational_argerror_typemin(T)
num = -num
end
return new(num, den)
end

function Rational{T}(num::Integer, den::Integer, ::Val{true}) where T<:Integer
# Used to skip all checks. This means we know num and den are coprime
# and cannot both be 0 and den >= 0
return new(num, den)
return new{T}(num, den)
end

end

@noinline __throw_rational_argerror_zero(T) = throw(ArgumentError("invalid rational: zero($T)//zero($T)"))
@noinline __throw_rational_argerror_typemin(T) = throw(ArgumentError("invalid rational: denominator can't be typemin($T)"))

function Rational{T}(num::Integer, den::Integer) where T<:Integer
num == den == zero(T) && __throw_rational_argerror_zero(T)
num2, den2 = divgcd(num, den)
return unsafe_rational{T}(num2, den2; checksign=true)
end

Rational(n::T, d::T) where {T<:Integer} = Rational{T}(n, d)
Rational(n::T, d::T, check::Val{b}) where {T<:Integer, b} = Rational{T}(n, d, check)
Rational(n::Integer, d::Integer) = Rational(promote(n,d)...)
Rational(n::Integer, d::Integer, check::Val{b}) where {b} = Rational(promote(n, d)..., check)
Rational(n::Integer) = Rational(n, one(n), Val(true))
Rational(n::Integer, d::Integer) = Rational(promote(n, d)...)
unsafe_rational(n::T, d::T; checksign=false) where {T<:Integer} = unsafe_rational{T}(n, d; checksign)
unsafe_rational(n::Integer, d::Integer; checksign=false) = unsafe_rational(promote(n, d)...; checksign)
Rational(n::Integer) = unsafe_rational(n, one(n))

function divgcd(x::Integer,y::Integer)
g = gcd(x,y)
Expand All @@ -70,16 +75,16 @@ julia> (3 // 5) // (2 // 1)

function //(x::Rational, y::Integer)
xn,yn = divgcd(x.num,y)
Rational(xn, checked_mul(x.den, yn), Val(false))
unsafe_rational(xn, checked_mul(x.den, yn); checksign=true)
end
function //(x::Integer, y::Rational)
xn,yn = divgcd(x,y.num)
Rational(checked_mul(xn, y.den), yn, Val(false))
unsafe_rational(checked_mul(xn, y.den), yn; checksign=true)
end
function //(x::Rational, y::Rational)
xn,yn = divgcd(x.num,y.num)
xd,yd = divgcd(x.den,y.den)
Rational(checked_mul(xn, yd), checked_mul(xd, yn), Val(false))
unsafe_rational(checked_mul(xn, yd), checked_mul(xd, yn); checksign=true)
end

//(x::Complex, y::Real) = complex(real(x)//y, imag(x)//y)
Expand All @@ -103,8 +108,8 @@ function write(s::IO, z::Rational)
write(s,numerator(z),denominator(z))
end

Rational{T}(x::Rational) where {T<:Integer} = Rational{T}(convert(T,x.num), convert(T,x.den), Val(true))
Rational{T}(x::Integer) where {T<:Integer} = Rational{T}(convert(T,x), one(T), Val(true))
Rational{T}(x::Rational) where {T<:Integer} = unsafe_rational{T}(convert(T,x.num), convert(T,x.den))
Rational{T}(x::Integer) where {T<:Integer} = unsafe_rational{T}(convert(T,x), one(T))

Rational(x::Rational) = x

Expand All @@ -127,7 +132,7 @@ end
Rational(x::Float64) = Rational{Int64}(x)
Rational(x::Float32) = Rational{Int}(x)

big(q::Rational) = Rational(big(numerator(q)), big(denominator(q)), Val(true))
big(q::Rational) = unsafe_rational(big(numerator(q)), big(denominator(q)))

big(z::Complex{<:Rational{<:Integer}}) = Complex{Rational{BigInt}}(z)

Expand Down Expand Up @@ -163,7 +168,7 @@ function rationalize(::Type{T}, x::AbstractFloat, tol::Real) where T<:Integer
throw(OverflowError("cannot negate unsigned number"))
end
isnan(x) && return T(x)//one(T)
isinf(x) && return Rational(x < 0 ? -one(T) : one(T), zero(T), Val(true))
isinf(x) && return unsafe_rational(x < 0 ? -one(T) : one(T), zero(T))

p, q = (x < 0 ? -one(T) : one(T)), zero(T)
pp, qq = zero(T), one(T)
Expand Down Expand Up @@ -256,23 +261,23 @@ denominator(x::Rational) = x.den

sign(x::Rational) = oftype(x, sign(x.num))
signbit(x::Rational) = signbit(x.num)
copysign(x::Rational, y::Real) = Rational(copysign(x.num, y), x.den, Val(true))
copysign(x::Rational, y::Rational) = Rational(copysign(x.num, y.num), x.den, Val(true))
copysign(x::Rational, y::Real) = unsafe_rational(copysign(x.num, y), x.den)
copysign(x::Rational, y::Rational) = unsafe_rational(copysign(x.num, y.num), x.den)

abs(x::Rational) = Rational(abs(x.num), x.den)

typemin(::Type{Rational{T}}) where {T<:Signed} = Rational(-one(T), zero(T), Val(true))
typemin(::Type{Rational{T}}) where {T<:Integer} = Rational(zero(T), one(T), Val(true))
typemax(::Type{Rational{T}}) where {T<:Integer} = Rational(one(T), zero(T), Val(true))
typemin(::Type{Rational{T}}) where {T<:Signed} = unsafe_rational(-one(T), zero(T))
typemin(::Type{Rational{T}}) where {T<:Integer} = unsafe_rational(zero(T), one(T))
typemax(::Type{Rational{T}}) where {T<:Integer} = unsafe_rational(one(T), zero(T))

isinteger(x::Rational) = x.den == 1

+(x::Rational) = Rational(+x.num, x.den, Val(true))
-(x::Rational) = Rational(-x.num, x.den, Val(true))
+(x::Rational) = unsafe_rational(+x.num, x.den)
-(x::Rational) = unsafe_rational(-x.num, x.den)

function -(x::Rational{T}) where T<:BitSigned
x.num == typemin(T) && throw(OverflowError("rational numerator is typemin(T)"))
Rational(-x.num, x.den, Val(true))
unsafe_rational(-x.num, x.den)
end
function -(x::Rational{T}) where T<:Unsigned
x.num != zero(T) && throw(OverflowError("cannot negate unsigned number"))
Expand All @@ -287,14 +292,14 @@ for (op,chop) in ((:+,:checked_add), (:-,:checked_sub), (:rem,:rem), (:mod,:mod)
end

function ($op)(x::Rational, y::Integer)
Rational(($chop)(x.num, checked_mul(x.den, y)), x.den, Val(true))
unsafe_rational(($chop)(x.num, checked_mul(x.den, y)), x.den)
end
end
end
for (op,chop) in ((:+,:checked_add), (:-,:checked_sub))
@eval begin
function ($op)(y::Integer, x::Rational)
Rational(($chop)(checked_mul(x.den, y), x.num), x.den, Val(true))
unsafe_rational(($chop)(checked_mul(x.den, y), x.num), x.den)
end
end
end
Expand All @@ -309,20 +314,20 @@ end
function *(x::Rational, y::Rational)
xn, yd = divgcd(x.num, y.den)
xd, yn = divgcd(x.den, y.num)
Rational(checked_mul(xn, yn), checked_mul(xd, yd), Val(true))
unsafe_rational(checked_mul(xn, yn), checked_mul(xd, yd))
end
function *(x::Rational, y::Integer)
xd, yn = divgcd(x.den, y)
Rational(checked_mul(x.num, yn), xd, Val(true))
unsafe_rational(checked_mul(x.num, yn), xd)
end
function *(y::Integer, x::Rational)
yn, xd = divgcd(y, x.den)
Rational(checked_mul(yn, x.num), xd, Val(true))
unsafe_rational(checked_mul(yn, x.num), xd)
end
/(x::Rational, y::Union{Rational, Integer}) = x//y
/(x::Integer, y::Rational) = x//y
/(x::Rational, y::Complex{<:Union{Integer,Rational}}) = x//y
inv(x::Rational) = Rational(x.den, x.num, Val(false))
inv(x::Rational) = unsafe_rational(x.den, x.num; checksign=true)

fma(x::Rational, y::Rational, z::Rational) = x*y+z

Expand Down Expand Up @@ -441,7 +446,7 @@ function _round_rational(::Type{T}, x::Rational{Tr}, ::RoundingMode{:Nearest}) w
if denominator(x) == zero(Tr) && T <: Integer
throw(DivideError())
elseif denominator(x) == zero(Tr)
return convert(T, copysign(Rational(one(Tr), zero(Tr), Val(true)), numerator(x)))
return convert(T, copysign(unsafe_rational(one(Tr), zero(Tr)), numerator(x)))
end
q,r = divrem(numerator(x), denominator(x))
s = q
Expand All @@ -455,7 +460,7 @@ function _round_rational(::Type{T}, x::Rational{Tr}, ::RoundingMode{:NearestTies
if denominator(x) == zero(Tr) && T <: Integer
throw(DivideError())
elseif denominator(x) == zero(Tr)
return convert(T, copysign(Rational(one(Tr), zero(Tr), Val(true)), numerator(x)))
return convert(T, copysign(unsafe_rational(one(Tr), zero(Tr)), numerator(x)))
end
q,r = divrem(numerator(x), denominator(x))
s = q
Expand All @@ -469,7 +474,7 @@ function _round_rational(::Type{T}, x::Rational{Tr}, ::RoundingMode{:NearestTies
if denominator(x) == zero(Tr) && T <: Integer
throw(DivideError())
elseif denominator(x) == zero(Tr)
return convert(T, copysign(Rational(one(Tr), zero(Tr), Val(true)), numerator(x)))
return convert(T, copysign(unsafe_rational(one(Tr), zero(Tr)), numerator(x)))
end
q,r = divrem(numerator(x), denominator(x))
s = q
Expand Down Expand Up @@ -513,8 +518,8 @@ end

float(::Type{Rational{T}}) where {T<:Integer} = float(T)

gcd(x::Rational, y::Rational) = Rational(gcd(x.num, y.num), lcm(x.den, y.den), Val(true))
lcm(x::Rational, y::Rational) = Rational(lcm(x.num, y.num), gcd(x.den, y.den), Val(true))
gcd(x::Rational, y::Rational) = unsafe_rational(gcd(x.num, y.num), lcm(x.den, y.den))
lcm(x::Rational, y::Rational) = unsafe_rational(lcm(x.num, y.num), gcd(x.den, y.den))
function gcdx(x::Rational, y::Rational)
c = gcd(x, y)
if iszero(c.num)
Expand Down

0 comments on commit 303aed8

Please sign in to comment.