Skip to content

Commit

Permalink
Merge pull request #43360 from JuliaLang/jn/ranges-last
Browse files Browse the repository at this point in the history
refine and cleanup handling of range arithmetic
  • Loading branch information
vtjnash authored Dec 10, 2021
2 parents b6bca19 + ff185b7 commit 5d41a76
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 132 deletions.
121 changes: 64 additions & 57 deletions base/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

(:)(start::T, stop::T) where {T<:Real} = UnitRange{T}(start, stop)

(:)(start::T, stop::T) where {T} = (:)(start, oftype(stop-start, 1), stop)
(:)(start::T, stop::T) where {T} = (:)(start, oftype(stop >= start ? stop - start : start - stop, 1), stop)

# promote start and stop, leaving step alone
(:)(start::A, step, stop::C) where {A<:Real,C<:Real} =
Expand Down Expand Up @@ -164,7 +164,7 @@ _range(start::Any , step::Any , stop::Any , len::Any ) = range_error
range_length(len::Integer) = OneTo(len)

# Stop as the only argument
range_stop(stop) = range_start_stop(oneunit(stop), stop)
range_stop(stop) = range_start_stop(oftype(stop, 1), stop)
range_stop(stop::Integer) = range_length(stop)

# Stop and length as the only argument
Expand Down Expand Up @@ -200,10 +200,17 @@ function range_start_step_length(a::T, step, len::Integer) where {T}
_rangestyle(OrderStyle(T), ArithmeticStyle(T), a, step, len)
end

_rangestyle(::Ordered, ::ArithmeticWraps, a::T, step::S, len::Integer) where {T,S} =
StepRange{typeof(a+zero(step)),S}(a, step, a+step*(len-1))
_rangestyle(::Any, ::Any, a::T, step::S, len::Integer) where {T,S} =
StepRangeLen{typeof(a+zero(step)),T,S}(a, step, len)
function _rangestyle(::Ordered, ::ArithmeticWraps, a, step, len::Integer)
start = a + zero(step)
stop = a + step * (len - 1)
T = typeof(start)
return StepRange{T,typeof(step)}(start, step, convert(T, stop))
end
function _rangestyle(::Any, ::Any, a, step, len::Integer)
start = a + zero(step)
T = typeof(a)
return StepRangeLen{typeof(start),T,typeof(step)}(a, step, len)
end

range_start_step_stop(start, step, stop) = start:step:stop

Expand Down Expand Up @@ -306,19 +313,19 @@ struct StepRange{T,S} <: OrdinalRange{T,S}
stop::T

function StepRange{T,S}(start, step, stop) where {T,S}
sta = convert(T, start)
ste = convert(S, step)
sto = convert(T, stop)
new(sta, ste, steprange_last(sta,ste,sto))
start = convert(T, start)
step = convert(S, step)
stop = convert(T, stop)
return new(start, step, steprange_last(start, step, stop))
end
end

# to make StepRange constructor inlineable, so optimizer can see `step` value
function steprange_last(start::T, step, stop) where T
if isa(start,AbstractFloat) || isa(step,AbstractFloat)
function steprange_last(start, step, stop)::typeof(stop)
if isa(start, AbstractFloat) || isa(step, AbstractFloat)
throw(ArgumentError("StepRange should not be used with floating point"))
end
if isa(start,Integer) && !isinteger(start + step)
if isa(start, Integer) && !isinteger(start + step)
throw(ArgumentError("StepRange{<:Integer} cannot have non-integer step"))
end
z = zero(step)
Expand All @@ -335,30 +342,28 @@ function steprange_last(start::T, step, stop) where T
absdiff, absstep = stop > start ? (stop - start, step) : (start - stop, -step)

# Compute remainder as a nonnegative number:
if T <: Signed && absdiff < zero(absdiff)
# handle signed overflow with unsigned rem
remain = convert(T, unsigned(absdiff) % absstep)
if absdiff isa Signed && absdiff < zero(absdiff)
# unlikely, but handle the signed overflow case with unsigned rem
remain = convert(typeof(absdiff), unsigned(absdiff) % absstep)
else
remain = absdiff % absstep
remain = convert(typeof(absdiff), absdiff % absstep)
end
# Move `stop` closer to `start` if there is a remainder:
last = stop > start ? stop - remain : stop + remain
end
end
last
return last
end

function steprange_last_empty(start::Integer, step, stop)
# empty range has a special representation where stop = start-1
# this is needed to avoid the wrap-around that can happen computing
# start - step, which leads to a range that looks very large instead
# of empty.
function steprange_last_empty(start::Integer, step, stop)::typeof(stop)
# empty range has a special representation where stop = start-1,
# which simplifies arithmetic for Signed numbers
if step > zero(step)
last = start - oneunit(stop-start)
last = start - oneunit(step)
else
last = start + oneunit(stop-start)
last = start + oneunit(step)
end
last
return last
end
# For types where x+oneunit(x) may not be well-defined use the user-given value for stop
steprange_last_empty(start, step, stop) = stop
Expand Down Expand Up @@ -388,18 +393,21 @@ UnitRange{Int64}
struct UnitRange{T<:Real} <: AbstractUnitRange{T}
start::T
stop::T
UnitRange{T}(start, stop) where {T<:Real} = new(start, unitrange_last(start,stop))
UnitRange{T}(start::T, stop::T) where {T<:Real} = new(start, unitrange_last(start, stop))
end
UnitRange{T}(start, stop) where {T<:Real} = UnitRange{T}(convert(T, start), convert(T, stop))
UnitRange(start::T, stop::T) where {T<:Real} = UnitRange{T}(start, stop)
UnitRange(start, stop) = UnitRange(promote(start, stop)...)

unitrange_last(::Bool, stop::Bool) = stop
unitrange_last(start::T, stop::T) where {T<:Integer} =
stop >= start ? stop : convert(T,start-oneunit(start-stop))
unitrange_last(start::T, stop::T) where {T} =
stop >= start ? convert(T,start+floor(stop-start)) :
convert(T,start-oneunit(stop-start))
# if stop and start are integral, we know that their difference is a multiple of 1
unitrange_last(start::Integer, stop::Integer) =
stop >= start ? stop : convert(typeof(stop), start - oneunit(start - stop))
# otherwise, use `floor` as a more efficient way to compute modulus with step=1
unitrange_last(start, stop) =
stop >= start ? convert(typeof(stop), start + floor(stop - start)) :
convert(typeof(stop), start - oneunit(start - stop))

unitrange(x) = UnitRange(x)
unitrange(x::AbstractUnitRange) = UnitRange(x) # convenience conversion for promoting the range type

if isdefined(Main, :Base)
# Constant-fold-able indexing into tuples to functionally expose Base.tail and Base.front
Expand Down Expand Up @@ -556,7 +564,7 @@ function LinRange{T}(start, stop, len::Integer) where T
end

function LinRange(start, stop, len::Integer)
T = typeof((stop-start)/len)
T = typeof((zero(stop) - zero(start)) / oneunit(len))
LinRange{T}(start, stop, len)
end

Expand Down Expand Up @@ -642,7 +650,7 @@ length(r::AbstractRange) = error("length implementation missing") # catch mistak
size(r::AbstractRange) = (length(r),)

isempty(r::StepRange) =
# steprange_last_empty(r.start, r.step, r.stop) == r.stop
# steprange_last(r.start, r.step, r.stop) == r.stop
(r.start != r.stop) & ((r.step > zero(r.step)) != (r.stop > r.start))
isempty(r::AbstractUnitRange) = first(r) > last(r)
isempty(r::StepRangeLen) = length(r) == 0
Expand Down Expand Up @@ -689,9 +697,8 @@ firstindex(::LinRange) = 1
# defined between the relevant types
function checked_length(r::OrdinalRange{T}) where T
s = step(r)
# s != 0, by construction, but avoids the division error later
start = first(r)
if s == zero(s) || isempty(r)
if isempty(r)
return Integer(div(start - start, oneunit(s)))
end
stop = last(r)
Expand All @@ -716,9 +723,8 @@ end

function length(r::OrdinalRange{T}) where T
s = step(r)
# s != 0, by construction, but avoids the division error later
start = first(r)
if s == zero(s) || isempty(r)
if isempty(r)
return Integer(div(start - start, oneunit(s)))
end
stop = last(r)
Expand Down Expand Up @@ -756,7 +762,6 @@ let bigints = Union{Int, UInt, Int64, UInt64, Int128, UInt128}
# (near typemax) for types with known `unsigned` functions
function length(r::OrdinalRange{T}) where T<:bigints
s = step(r)
s == zero(s) && return zero(T) # unreachable, by construction, but avoids the error case here later
isempty(r) && return zero(T)
diff = last(r) - first(r)
# if |s| > 1, diff might have overflowed, but unsigned(diff)÷s should
Expand All @@ -773,7 +778,6 @@ let bigints = Union{Int, UInt, Int64, UInt64, Int128, UInt128}
end
function checked_length(r::OrdinalRange{T}) where T<:bigints
s = step(r)
s == zero(s) && return zero(T) # unreachable, by construction, but avoids the error case here later
isempty(r) && return zero(T)
stop, start = last(r), first(r)
# n.b. !(s isa T)
Expand All @@ -800,7 +804,6 @@ let smallints = (Int === Int64 ?
# n.b. !(step isa T)
function length(r::OrdinalRange{<:smallints})
s = step(r)
s == zero(s) && return 0 # unreachable, by construction, but avoids the error case here later
isempty(r) && return 0
return div(Int(last(r)) - Int(first(r)), s) + 1
end
Expand Down Expand Up @@ -962,29 +965,30 @@ function getindex(r::AbstractUnitRange, s::AbstractUnitRange{T}) where {T<:Integ
@boundscheck checkbounds(r, s)

if T === Bool
range(first(s) ? first(r) : last(r), length = Integer(last(s)))
return range(first(s) ? first(r) : last(r), length = last(s))
else
f = first(r)
st = oftype(f, f + first(s)-firstindex(r))
return range(st, length=length(s))
start = oftype(f, f + first(s)-firstindex(r))
return range(start, length=length(s))
end
end

function getindex(r::OneTo{T}, s::OneTo) where T
@inline
@boundscheck checkbounds(r, s)
OneTo(T(s.stop))
return OneTo(T(s.stop))
end

function getindex(r::AbstractUnitRange, s::StepRange{T}) where {T<:Integer}
@inline
@boundscheck checkbounds(r, s)

if T === Bool
range(first(s) ? first(r) : last(r), step=oneunit(eltype(r)), length = Integer(last(s)))
return range(first(s) ? first(r) : last(r), step=oneunit(eltype(r)), length = last(s))
else
st = oftype(first(r), first(r) + s.start-firstindex(r))
return range(st, step=step(s), length=length(s))
f = first(r)
start = oftype(f, f + s.start-firstindex(r))
return range(start, step=step(s), length=length(s))
end
end

Expand All @@ -994,19 +998,22 @@ function getindex(r::StepRange, s::AbstractRange{T}) where {T<:Integer}

if T === Bool
if length(s) == 0
return range(first(r), step=step(r), length=0)
start, len = first(r), 0
elseif length(s) == 1
if first(s)
return range(first(r), step=step(r), length=1)
start, len = first(r), 1
else
return range(first(r), step=step(r), length=0)
start, len = first(r), 0
end
else # length(s) == 2
return range(last(r), step=step(r), length=1)
start, len = last(r), 1
end
return range(start, step=step(r); length=len)
else
st = oftype(r.start, r.start + (first(s)-1)*step(r))
return range(st, step=step(r)*step(s), length=length(s))
f = r.start
st = r.step
start = oftype(f, f + (first(s)-oneunit(first(s)))*st)
return range(start; step=st*step(s), length=length(s))
end
end

Expand Down Expand Up @@ -1235,7 +1242,7 @@ end
issubset(r::OneTo, s::OneTo) = r.stop <= s.stop

issubset(r::AbstractUnitRange{<:Integer}, s::AbstractUnitRange{<:Integer}) =
isempty(r) || first(r) >= first(s) && last(r) <= last(s)
isempty(r) || (first(r) >= first(s) && last(r) <= last(s))

## linear operations on ranges ##

Expand Down
2 changes: 1 addition & 1 deletion stdlib/Dates/src/Dates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ for more information.
"""
module Dates

import Base: ==, div, fld, mod, rem, gcd, lcm, +, -, *, /, %, broadcast
import Base: ==, isless, div, fld, mod, rem, gcd, lcm, +, -, *, /, %, broadcast
using Printf: @sprintf

using Base.Iterators
Expand Down
7 changes: 5 additions & 2 deletions stdlib/Dates/src/periods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,7 @@ default(p::Union{T,Type{T}}) where {T<:TimePeriod} = T(0)

(-)(x::P) where {P<:Period} = P(-value(x))
==(x::P, y::P) where {P<:Period} = value(x) == value(y)
==(x::Period, y::Period) = (==)(promote(x, y)...)
Base.isless(x::P, y::P) where {P<:Period} = isless(value(x), value(y))
Base.isless(x::Period, y::Period) = isless(promote(x, y)...)

# Period Arithmetic, grouped by dimensionality:
for op in (:+, :-, :lcm, :gcd)
Expand All @@ -97,6 +95,11 @@ end
(*)(A::Period, B::AbstractArray) = Broadcast.broadcast_preserving_zero_d(*, A, B)
(*)(A::AbstractArray, B::Period) = Broadcast.broadcast_preserving_zero_d(*, A, B)

for op in (:(==), :isless, :/, :rem, :mod, :lcm, :gcd)
@eval ($op)(x::Period, y::Period) = ($op)(promote(x, y)...)
end
div(x::Period, y::Period, r::RoundingMode) = div(promote(x, y)..., r)

# intfuncs
Base.gcdx(a::T, b::T) where {T<:Period} = ((g, x, y) = gcdx(value(a), value(b)); return T(g), x, y)
Base.abs(a::T) where {T<:Period} = T(abs(value(a)))
Expand Down
7 changes: 4 additions & 3 deletions stdlib/Dates/src/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ end
Base.length(r::StepRange{<:TimeType}) = isempty(r) ? Int64(0) : len(r.start, r.stop, r.step) + 1
# Period ranges hook into Int64 overflow detection
Base.length(r::StepRange{<:Period}) = length(StepRange(value(r.start), value(r.step), value(r.stop)))
Base.checked_length(r::StepRange{<:Period}) = Base.checked_length(StepRange(value(r.start), value(r.step), value(r.stop)))

# Overload Base.steprange_last because `rem` is not overloaded for `TimeType`s
# Overload Base.steprange_last because `step::Period` may be a variable amount of time (e.g. for Month and Year)
function Base.steprange_last(start::T, step, stop) where T<:TimeType
if isa(step,AbstractFloat)
if isa(step, AbstractFloat)
throw(ArgumentError("StepRange should not be used with floating point"))
end
z = zero(step)
Expand All @@ -47,7 +48,7 @@ function Base.steprange_last(start::T, step, stop) where T<:TimeType
last = stop - remain
end
end
last
return last
end

import Base.in
Expand Down
Loading

0 comments on commit 5d41a76

Please sign in to comment.