Skip to content

Commit

Permalink
Fix ustrip broadcasting when range step has different unit than elt…
Browse files Browse the repository at this point in the history
…ype (#715)
  • Loading branch information
sostock authored May 13, 2024
1 parent 6bfc193 commit 9aa8703
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 7 deletions.
44 changes: 42 additions & 2 deletions src/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,54 @@ broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::AbstractQu
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::AbstractQuantity, r::AbstractRange) =
broadcasted(DefaultArrayStyle{1}(), *, ustrip(x), r) * unit(x)

const BCAST_PROPAGATE_CALLS = Union{typeof(upreferred), typeof(ustrip), Units}
const BCAST_PROPAGATE_CALLS = Union{typeof(upreferred), Units}
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::Ref{<:Units}) = r * x[]
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Ref{<:Units}, r::AbstractRange) = x[] * r
broadcasted(::DefaultArrayStyle{1}, x::BCAST_PROPAGATE_CALLS, r::StepRangeLen) = StepRangeLen{typeof(x(zero(eltype(r))))}(x(r.ref), x(r.step), r.len, r.offset)
broadcasted(::DefaultArrayStyle{1}, x::BCAST_PROPAGATE_CALLS, r::StepRange) = StepRange(x(r.start), x(r.step), x(r.stop))
function broadcasted(::DefaultArrayStyle{1}, x::BCAST_PROPAGATE_CALLS, r::StepRange)
start = x(r.start)
au_to = absoluteunit(unit(start))
step = uconvert(au_to, r.step)
if Base.ArithmeticStyle(start) == Base.ArithmeticRounds() || Base.ArithmeticStyle(step) == Base.ArithmeticRounds()
au_from = absoluteunit(unit(r.start))
astart = ustrip(au_from, r.start)
astop = ustrip(au_from, r.stop)
len = length(r)
offset = _offset_for_steprangelen(astart, astop, len)
nb = ndigits(max(offset-1, len-offset), base=2, pad=0)
T = promote_type(typeof(start/unit(start)), typeof(step/unit(step)))
unitless_range = Base.steprangelen_hp(T, ustrip(au_to, r[offset]), ustrip(au_to, step), nb, len, offset)
return unitless_range * unit(start)
else
return StepRange(start, step, x(r.stop))
end
end
broadcasted(::DefaultArrayStyle{1}, x::BCAST_PROPAGATE_CALLS, r::LinRange) = LinRange(x(r.start), x(r.stop), r.len)
broadcasted(::DefaultArrayStyle{1}, ::typeof(|>), r::AbstractRange, x::Ref{<:BCAST_PROPAGATE_CALLS}) = broadcasted(DefaultArrayStyle{1}(), x[], r)

function _offset_for_steprangelen(start, stop, len)
if iszero(start)
return oneunit(len)
elseif iszero(stop)
return len
elseif signbit(start) == signbit(stop)
return abs(start) < abs(stop) ? oneunit(len) : len
else
fstart = Float64(start)
fstop = Float64(stop)
return round(typeof(len), (fstop-len*fstart)/(fstop-fstart))
end
end

broadcasted(::DefaultArrayStyle{1}, ::typeof(ustrip), r::StepRangeLen) =
StepRangeLen{typeof(ustrip(zero(eltype(r))))}(ustrip(unit(eltype(r)), r.ref), ustrip(unit(eltype(r)), r.step), r.len, r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(ustrip), r::StepRange) =
ustrip(unit(eltype(r)), r.start):ustrip(unit(eltype(r)), r.step):ustrip(unit(eltype(r)), r.stop)
broadcasted(::DefaultArrayStyle{1}, ::typeof(ustrip), r::LinRange) =
LinRange(ustrip(unit(eltype(r)), r.start), ustrip(unit(eltype(r)), r.stop), r.len)
broadcasted(::DefaultArrayStyle{1}, ::typeof(|>), r::AbstractRange, ::Ref{typeof(ustrip)}) =
broadcasted(DefaultArrayStyle{1}(), ustrip, r)

# for ambiguity resolution
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::StepRangeLen{T}, x::AbstractQuantity) where T =
broadcasted(DefaultArrayStyle{1}(), *, r, ustrip(x)) * unit(x)
Expand Down
53 changes: 48 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import Unitful:
Ra, °F, °C, K,
rad, mrad, °,
ms, s, minute, hr, d, yr, Hz,
J, A, N, mol, V,
J, A, N, mol, V, mJ, eV,
mW, W,
dB, dB_rp, dB_p, dBm, dBV, dBSPL, Decibel,
Np, Np_rp, Np_p, Neper,
Expand Down Expand Up @@ -1325,14 +1325,61 @@ end

@test @inferred((1:2:5) .* cm .|> mm) === 10mm:20mm:50mm
@test mm.((1:2:5) .* cm) === 10mm:20mm:50mm
@test @inferred(StepRange(1cm,1mm,2cm) .|> km) === (1//100_000)km:(1//1_000_000)km:(2//100_000)km

@test @inferred((1:2:5) .* km .|> upreferred) === 1000m:2000m:5000m
@test @inferred((1:2:5)km .|> upreferred) === 1000m:2000m:5000m
@test @inferred((1:2:5) .|> upreferred) === 1:2:5
@test @inferred((1.0:2.0:5.0) .* km .|> upreferred) === 1000.0m:2000.0m:5000.0m
@test @inferred((1.0:2.0:5.0)km .|> upreferred) === 1000.0m:2000.0m:5000.0m
@test @inferred((1.0:2.0:5.0) .|> upreferred) === 1.0:2.0:5.0
@test @inferred(StepRange(1cm,1mm,2cm) .|> upreferred) === (1//100)m:(1//1000)m:(2//100)m

# float conversion, dimensionful
for r = [1eV:1eV:5eV, 1eV:1eV:5_000_000eV, 5_000_000eV:-1eV:-1eV, -123_456_789eV:2eV:987_654_321eV, (-11//12)eV:(1//3)eV:(11//4)eV]
for f = (mJ, upreferred)
rf = @inferred(r .|> f)
test_indices = length(r) 10_000 ? eachindex(r) : rand(eachindex(r), 10_000)
@test eltype(rf) === typeof(f(zero(eltype(r))))
@test all((rf[i], f(r[i]); rtol=eps()) for i = test_indices)
end
end

# float conversion from unitless
r = 1:1:360
rf = °.(r)
@test all((rf[i], °(r[i]); rtol=eps()) for i = eachindex(r))

# float conversion to unitless
r = (1:1:360
for f = (mrad, NoUnits, upreferred)
rf = f.(r)
@test eltype(rf) === typeof(f(zero(eltype(r))))
@test all((rf[i], f(r[i]); rtol=eps()) for i = eachindex(r))
end

# exact conversion from and to unitless
@test rad.(1:1:360) === (1:1:360)rad
@test mrad.(1:1:360) === (1_000:1_000:360_000)mrad
@test upreferred.(1:1:360) === 1:1:360
@test NoUnits.((1:1:360)rad) === 1:1:360
@test upreferred.((1:1:360)rad) === 1:1:360
@test NoUnits.((1:2:5)mrad) === 1//1000:1//500:1//200
@test upreferred.((1:2:5)mrad) === 1//1000:1//500:1//200

@test @inferred((1:2:5) .* cm .|> mm .|> ustrip) === 10:20:50
@test @inferred((1f0:2f0:5f0) .* cm .|> mm .|> ustrip) === 10f0:20f0:50f0
@test @inferred(StepRange{typeof(1m),typeof(1cm)}(1m,1cm,2m) .|> ustrip) === 1:1//100:2
@test @inferred(StepRangeLen{typeof(1f0m)}(1.0m, 1.0cm, 101) .|> ustrip) === StepRangeLen{Float32}(1.0, 0.01, 101)
@test @inferred(StepRangeLen{typeof(1.0m)}(Base.TwicePrecision(1.0m), Base.TwicePrecision(1.0cm), 101) .|> ustrip) === StepRangeLen{Float64}(Base.TwicePrecision(1.0), Base.TwicePrecision(0.01), 101)
@test @inferred((1:0.1:1.0) .|> ustrip) == 1:0.1:1.0
@test @inferred((1m:0.1m:1.0m) .|> ustrip) == 1:0.1:1.0
@test @inferred(StepRange{typeof(0m),typeof(1cm)}(1m,1cm,2m) .|> ustrip) === 1:1//100:2
@test @inferred(StepRangeLen{typeof(1f0m)}(1.0m, 1.0cm, 101) .|> ustrip) === StepRangeLen{Float32}(1.0, 0.01, 101)
@test @inferred(StepRangeLen{typeof(1.0m)}(Base.TwicePrecision(1.0m), Base.TwicePrecision(1.0cm), 101) .|> ustrip) === StepRangeLen{Float64}(Base.TwicePrecision(1.0), Base.TwicePrecision(0.01), 101)
@test @inferred(StepRangeLen{typeof(1.0mm)}(Base.TwicePrecision(1.0m), Base.TwicePrecision(1.0cm), 101) .|> ustrip) === 1000.0:10.0:2000.0
@test ustrip.(1:0.1:1.0) == 1:0.1:1.0
@test ustrip.(1m:0.1m:1.0m) == 1:0.1:1.0
end
@testset ">> quantities and non-quantities" begin
@test range(1, step=1m/mm, length=5) == 1:1000:4001
Expand Down Expand Up @@ -1400,10 +1447,6 @@ end
@test_throws ArgumentError range(step=1m, length=5)
end
end
@testset ">> broadcast ustrip" begin
@test ustrip.(1:0.1:1.0) == 1:0.1:1.0
@test ustrip.(1m:0.1m:1.0m) == 1:0.1:1.0
end
end
@testset "> Arrays" begin
@testset ">> Array multiplication" begin
Expand Down

0 comments on commit 9aa8703

Please sign in to comment.