Skip to content

Commit

Permalink
add length type parameter to StepRangeLen
Browse files Browse the repository at this point in the history
Also be more careful about using additive identity instead of
multiplicative, and be more consistent about types in a few places.

Fixes #41517
  • Loading branch information
vtjnash committed Jul 16, 2021
1 parent 2893de7 commit 72c1e7f
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 116 deletions.
21 changes: 12 additions & 9 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1121,19 +1121,20 @@ end

## scalar-range broadcast operations ##
# DefaultArrayStyle and \ are not available at the time of range.jl
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::OrdinalRange) = r
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::StepRangeLen) = r
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::LinRange) = r
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractRange) = r

broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::OrdinalRange) = range(-first(r), step=-step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractRange) = range(-first(r), step=-step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::OrdinalRange) = range(-first(r), -last(r), step=-step(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::StepRangeLen) = StepRangeLen(-r.ref, -r.step, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::LinRange) = LinRange(-r.start, -r.stop, length(r))

broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Real, r::AbstractUnitRange) = range(x + first(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractUnitRange, x::Real) = range(first(r) + x, length=length(r))
# For #18336 we need to prevent promotion of the step type:
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractRange, x::Number) = range(first(r) + x, step=step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::AbstractRange) = range(x + first(r), step=step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::OrdinalRange, x::Real) = range(first(r) + x, last(r) + x, step=step(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Real, r::Real) = range(x + first(r), x + last(r), step=step(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractUnitRange, x::Real) = range(first(r) + x, last(r) + x)
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Real, r::AbstractUnitRange) = range(x + first(r), x + last(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::StepRangeLen{T}, x::Number) where T =
StepRangeLen{typeof(T(r.ref)+x)}(r.ref + x, r.step, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::StepRangeLen{T}) where T =
Expand All @@ -1142,9 +1143,11 @@ broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::LinRange, x::Number) = LinRa
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::LinRange) = LinRange(x + r.start, x + r.stop, length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r1::AbstractRange, r2::AbstractRange) = r1 + r2

broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractUnitRange, x::Number) = range(first(r)-x, length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractRange, x::Number) = range(first(r)-x, step=step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::AbstractRange) = range(x-first(r), step=-step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractRange, x::Number) = range(first(r) - x, step=step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::AbstractRange) = range(x - first(r), step=-step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::OrdinalRange, x::Real) = range(first(r) - x, last(r) - x, step=step(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Real, r::OrdinalRange) = range(x - first(r), x - last(r), step=-step(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractUnitRange, x::Real) = range(first(r) - x, last(r) - x)
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::StepRangeLen{T}, x::Number) where T =
StepRangeLen{typeof(T(r.ref)-x)}(r.ref - x, r.step, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::StepRangeLen{T}) where T =
Expand Down
111 changes: 64 additions & 47 deletions base/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
_colon(::Ordered, ::Any, start::T, step, stop::T) where {T} = StepRange(start, step, stop)
# for T<:Union{Float16,Float32,Float64} see twiceprecision.jl
_colon(::Ordered, ::ArithmeticRounds, start::T, step, stop::T) where {T} =
StepRangeLen(start, step, floor(Int, (stop-start)/step)+1)
StepRangeLen(start, step, floor(Integer, (stop-start)/step)+1)
_colon(::Any, ::Any, start::T, step, stop::T) where {T} =
StepRangeLen(start, step, floor(Int, (stop-start)/step)+1)
StepRangeLen(start, step, floor(Integer, (stop-start)/step)+1)

"""
(:)(start, [step], stop)
Expand Down Expand Up @@ -415,8 +415,9 @@ oneto(r) = OneTo(r)
## Step ranges parameterized by length

"""
StepRangeLen{T,R,S}(ref::R, step::S, len, [offset=1]) where {T,R,S}
StepRangeLen( ref::R, step::S, len, [offset=1]) where { R,S}
StepRangeLen( ref::R, step::S, len, [offset=1]) where { R,S}
StepRangeLen{T,R,S}( ref::R, step::S, len, [offset=1]) where {T,R,S}
StepRangeLen{T,R,S,L}(ref::R, step::S, len, [offset=1]) where {T,R,S,L}
A range `r` where `r[i]` produces values of type `T` (in the second
form, `T` is deduced automatically), parameterized by a `ref`erence
Expand All @@ -426,26 +427,30 @@ value `r[1]`, but alternatively you can supply it as the value of
with `TwicePrecision` this can be used to implement ranges that are
free of roundoff error.
"""
struct StepRangeLen{T,R,S} <: AbstractRange{T}
struct StepRangeLen{T,R,S,L} <: AbstractRange{T}
ref::R # reference value (might be smallest-magnitude value in the range)
step::S # step value
len::Int # length of the range
offset::Int # the index of ref
len::L # length of the range
offset::L # the index of ref

function StepRangeLen{T,R,S}(ref::R, step::S, len::Integer, offset::Integer = 1) where {T,R,S}
function StepRangeLen{T,R,S,L}(ref::R, step::S, len::Integer, offset::Integer = 1) where {T,R,S,L}
if T <: Integer && !isinteger(ref + step)
throw(ArgumentError("StepRangeLen{<:Integer} cannot have non-integer step"))
end
len = convert(L, len)
len >= 0 || throw(ArgumentError("length cannot be negative, got $len"))
1 <= offset <= max(1,len) || throw(ArgumentError("StepRangeLen: offset must be in [1,$len], got $offset"))
new(ref, step, len, offset)
offset = convert(L, offset)
1 <= offset <= max(1, len) || throw(ArgumentError("StepRangeLen: offset must be in [1,$len], got $offset"))
return new(ref, step, len, offset)
end
end

StepRangeLen{T,R,S}(ref::R, step::S, len::Integer, offset::Integer = 1) where {T,R,S} =
StepRangeLen{T,R,S,promote_type(Int,typeof(len))}(ref, step, len, offset)
StepRangeLen(ref::R, step::S, len::Integer, offset::Integer = 1) where {R,S} =
StepRangeLen{typeof(ref+zero(step)),R,S}(ref, step, len, offset)
StepRangeLen{typeof(ref+zero(step)),R,S,promote_type(Int,typeof(len))}(ref, step, len, offset)
StepRangeLen{T}(ref::R, step::S, len::Integer, offset::Integer = 1) where {T,R,S} =
StepRangeLen{T,R,S}(ref, step, len, offset)
StepRangeLen{T,R,S,promote_type(Int,typeof(len))}(ref, step, len, offset)

## range with computed step

Expand Down Expand Up @@ -621,6 +626,7 @@ step(r::StepRangeLen) = r.step
step(r::StepRangeLen{T}) where {T<:AbstractFloat} = T(r.step)
step(r::LinRange) = (last(r)-first(r))/r.lendiv

# high-precision step
step_hp(r::StepRangeLen) = r.step
step_hp(r::AbstractRange) = step(r)

Expand Down Expand Up @@ -648,7 +654,7 @@ function checked_length(r::OrdinalRange{T}) where T
diff = checked_sub(stop, start)
end
a = Integer(div(diff, s))
return checked_add(a, one(a))
return checked_add(a, oneunit(a))
end

function checked_length(r::AbstractUnitRange{T}) where T
Expand All @@ -657,7 +663,7 @@ function checked_length(r::AbstractUnitRange{T}) where T
return Integer(first(r) - first(r))
end
a = Integer(checked_add(checked_sub(last(r), first(r))))
return checked_add(a, one(a))
return checked_add(a, oneunit(a))
end

function length(r::OrdinalRange{T}) where T
Expand All @@ -675,14 +681,14 @@ function length(r::OrdinalRange{T}) where T
diff = stop - start
end
a = Integer(div(diff, s))
return a + one(a)
return a + oneunit(a)
end


function length(r::AbstractUnitRange{T}) where T
@_inline_meta
a = Integer(last(r) - first(r)) # even when isempty, by construction (with overflow)
return a + one(a)
return a + oneunit(a)
end

length(r::OneTo) = Integer(r.stop - zero(r.stop))
Expand Down Expand Up @@ -710,7 +716,7 @@ let bigints = Union{Int, UInt, Int64, UInt64, Int128, UInt128}
else
a = div(unsigned(diff), s) % typeof(diff)
end
return Integer(a) + one(a)
return Integer(a) + oneunit(a)
end
function checked_length(r::OrdinalRange{T}) where T<:bigints
s = step(r)
Expand All @@ -729,7 +735,7 @@ let bigints = Union{Int, UInt, Int64, UInt64, Int128, UInt128}
else
a = div(checked_sub(start, stop), -s)
end
return checked_add(a, one(a))
return checked_add(a, oneunit(a))
end
end

Expand Down Expand Up @@ -803,7 +809,13 @@ copy(r::AbstractRange) = r

## iteration

function iterate(r::Union{LinRange,StepRangeLen}, i::Int=1)
function iterate(r::StepRangeLen, i::Integer=1)
@_inline_meta
length(r) < i && return nothing
unsafe_getindex(r, i), i + 1
end

function iterate(r::LinRange, i::Int=1)
@_inline_meta
length(r) < i && return nothing
unsafe_getindex(r, i), i + 1
Expand Down Expand Up @@ -897,7 +909,7 @@ function getindex(r::AbstractUnitRange, s::AbstractUnitRange{T}) where {T<:Integ
@boundscheck checkbounds(r, s)

if T === Bool
range(first(s) ? first(r) : last(r), length = Int(last(s)))
range(first(s) ? first(r) : last(r), length = Integer(last(s)))
else
f = first(r)
st = oftype(f, f + first(s)-1)
Expand All @@ -916,7 +928,7 @@ function getindex(r::AbstractUnitRange, s::StepRange{T}) where {T<:Integer}
@boundscheck checkbounds(r, s)

if T === Bool
range(first(s) ? first(r) : last(r), step=oneunit(eltype(r)), length = Int(last(s)))
range(first(s) ? first(r) : last(r), step=oneunit(eltype(r)), length = Integer(last(s)))
else
st = oftype(first(r), first(r) + s.start-1)
return range(st, step=step(s), length=length(s))
Expand Down Expand Up @@ -949,24 +961,29 @@ function getindex(r::StepRangeLen{T}, s::OrdinalRange{S}) where {T, S<:Integer}
@_inline_meta
@boundscheck checkbounds(r, s)

len = length(s)
sstep = step_hp(s)
rstep = step_hp(r)
L = typeof(len)
if S === Bool
if length(s) == 0
return StepRangeLen{T}(first(r), step(r), 0, 1)
elseif length(s) == 1
rstep *= one(sstep)
if len == 0
return StepRangeLen{T}(first(r), rstep, zero(L), oneunit(L))
elseif len == 1
if first(s)
return StepRangeLen{T}(first(r), step(r), 1, 1)
return StepRangeLen{T}(first(r), rstep, oneunit(L), oneunit(L))
else
return StepRangeLen{T}(first(r), step(r), 0, 1)
return StepRangeLen{T}(first(r), rstep, zero(L), oneunit(L))
end
else # length(s) == 2
return StepRangeLen{T}(last(r), step(r), 1, 1)
else # len == 2
return StepRangeLen{T}(last(r), rstep, oneunit(L), oneunit(L))
end
else
# Find closest approach to offset by s
ind = LinearIndices(s)
offset = max(min(1 + round(Int, (r.offset - first(s))/step(s)), last(ind)), first(ind))
ref = _getindex_hiprec(r, first(s) + (offset-1)*step(s))
return StepRangeLen{T}(ref, r.step*step(s), length(s), offset)
offset = L(max(min(1 + round(L, (r.offset - first(s))/sstep), last(ind)), first(ind)))
ref = _getindex_hiprec(r, first(s) + (offset-1)*sstep)
return StepRangeLen{T}(ref, rstep*sstep, len, offset)
end
end

Expand Down Expand Up @@ -1153,8 +1170,8 @@ issubset(r::AbstractUnitRange{<:Integer}, s::AbstractUnitRange{<:Integer}) =
## linear operations on ranges ##

-(r::OrdinalRange) = range(-first(r), step=-step(r), length=length(r))
-(r::StepRangeLen{T,R,S}) where {T,R,S} =
StepRangeLen{T,R,S}(-r.ref, -r.step, length(r), r.offset)
-(r::StepRangeLen{T,R,S,L}) where {T,R,S,L} =
StepRangeLen{T,R,S,L}(-r.ref, -r.step, r.len, r.offset)
function -(r::LinRange)
start = -r.start
LinRange{typeof(start)}(start, -r.stop, length(r))
Expand Down Expand Up @@ -1206,20 +1223,20 @@ StepRange(r::AbstractUnitRange{T}) where {T} =
StepRange{T,T}(first(r), step(r), last(r))
(StepRange{T1,T2} where T1)(r::AbstractRange) where {T2} = StepRange{eltype(r),T2}(r)

promote_rule(::Type{StepRangeLen{T1,R1,S1}},::Type{StepRangeLen{T2,R2,S2}}) where {T1,T2,R1,R2,S1,S2} =
promote_rule(::Type{StepRangeLen{T1,R1,S1,L1}},::Type{StepRangeLen{T2,R2,S2,L2}}) where {T1,T2,R1,R2,S1,S2,L1,L2} =
el_same(promote_type(T1,T2),
StepRangeLen{T1,promote_type(R1,R2),promote_type(S1,S2)},
StepRangeLen{T2,promote_type(R1,R2),promote_type(S1,S2)})
StepRangeLen{T,R,S}(r::StepRangeLen{T,R,S}) where {T,R,S} = r
StepRangeLen{T,R,S}(r::StepRangeLen) where {T,R,S} =
StepRangeLen{T,R,S}(convert(R, r.ref), convert(S, r.step), length(r), r.offset)
StepRangeLen{T1,promote_type(R1,R2),promote_type(S1,S2),promote_type(L1,L2)},
StepRangeLen{T2,promote_type(R1,R2),promote_type(S1,S2),promote_type(L1,L2)})
StepRangeLen{T,R,S,L}(r::StepRangeLen{T,R,S,L}) where {T,R,S,L} = r
StepRangeLen{T,R,S,L}(r::StepRangeLen) where {T,R,S,L} =
StepRangeLen{T,R,S,L}(convert(R, r.ref), convert(S, r.step), convert(L, r.len), convert(L, r.offset))
StepRangeLen{T}(r::StepRangeLen) where {T} =
StepRangeLen(convert(T, r.ref), convert(T, r.step), length(r), r.offset)
StepRangeLen(convert(T, r.ref), convert(T, r.step), r.len, r.offset)

promote_rule(a::Type{StepRangeLen{T,R,S}}, ::Type{OR}) where {T,R,S,OR<:AbstractRange} =
promote_rule(a, StepRangeLen{eltype(OR), eltype(OR), eltype(OR)})
StepRangeLen{T,R,S}(r::AbstractRange) where {T,R,S} =
StepRangeLen{T,R,S}(R(first(r)), S(step(r)), length(r))
promote_rule(a::Type{StepRangeLen{T,R,S,L}}, ::Type{OR}) where {T,R,S,L,OR<:AbstractRange} =
promote_rule(a, StepRangeLen{eltype(OR), eltype(OR), eltype(OR), Int})
StepRangeLen{T,R,S,L}(r::AbstractRange) where {T,R,S,L} =
StepRangeLen{T,R,S,L}(R(first(r)), S(step(r)), length(r))
StepRangeLen{T}(r::AbstractRange) where {T} =
StepRangeLen(T(first(r)), T(step(r)), length(r))
StepRangeLen(r::AbstractRange) = StepRangeLen{eltype(r)}(r)
Expand All @@ -1233,8 +1250,8 @@ LinRange(r::AbstractRange{T}) where {T} = LinRange{T}(r)
promote_rule(a::Type{LinRange{T}}, ::Type{OR}) where {T,OR<:OrdinalRange} =
promote_rule(a, LinRange{eltype(OR)})

promote_rule(::Type{LinRange{L}}, b::Type{StepRangeLen{T,R,S}}) where {L,T,R,S} =
promote_rule(StepRangeLen{L,L,L}, b)
promote_rule(::Type{LinRange{A}}, b::Type{StepRangeLen{T,R,S,L}}) where {A,T,R,S,L} =
promote_rule(StepRangeLen{A,A,A,Int}, b)

## concatenation ##

Expand All @@ -1261,7 +1278,7 @@ function _reverse(r::StepRangeLen, ::Colon)
# invalid. As `reverse(r)` is also empty, any offset would work so we keep
# `r.offset`
offset = isempty(r) ? r.offset : length(r)-r.offset+1
StepRangeLen(r.ref, -r.step, length(r), offset)
return typeof(r)(r.ref, -r.step, length(r), offset)
end
_reverse(r::LinRange{T}, ::Colon) where {T} = LinRange{T}(r.stop, r.start, length(r))

Expand Down
Loading

0 comments on commit 72c1e7f

Please sign in to comment.