From 9aa87037759a4a0b9c3ac34ebb2101b4b80048cc Mon Sep 17 00:00:00 2001 From: Sebastian Stock <42280794+sostock@users.noreply.github.com> Date: Mon, 13 May 2024 09:40:52 +0200 Subject: [PATCH] Fix `ustrip` broadcasting when range step has different unit than eltype (#715) --- src/range.jl | 44 ++++++++++++++++++++++++++++++++++++++-- test/runtests.jl | 53 +++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 90 insertions(+), 7 deletions(-) diff --git a/src/range.jl b/src/range.jl index 2ebb6b29..8c1921ae 100644 --- a/src/range.jl +++ b/src/range.jl @@ -154,14 +154,54 @@ broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::AbstractQu broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::AbstractQuantity, r::AbstractRange) = broadcasted(DefaultArrayStyle{1}(), *, ustrip(x), r) * unit(x) -const BCAST_PROPAGATE_CALLS = Union{typeof(upreferred), typeof(ustrip), Units} +const BCAST_PROPAGATE_CALLS = Union{typeof(upreferred), Units} broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::Ref{<:Units}) = r * x[] broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Ref{<:Units}, r::AbstractRange) = x[] * r broadcasted(::DefaultArrayStyle{1}, x::BCAST_PROPAGATE_CALLS, r::StepRangeLen) = StepRangeLen{typeof(x(zero(eltype(r))))}(x(r.ref), x(r.step), r.len, r.offset) -broadcasted(::DefaultArrayStyle{1}, x::BCAST_PROPAGATE_CALLS, r::StepRange) = StepRange(x(r.start), x(r.step), x(r.stop)) +function broadcasted(::DefaultArrayStyle{1}, x::BCAST_PROPAGATE_CALLS, r::StepRange) + start = x(r.start) + au_to = absoluteunit(unit(start)) + step = uconvert(au_to, r.step) + if Base.ArithmeticStyle(start) == Base.ArithmeticRounds() || Base.ArithmeticStyle(step) == Base.ArithmeticRounds() + au_from = absoluteunit(unit(r.start)) + astart = ustrip(au_from, r.start) + astop = ustrip(au_from, r.stop) + len = length(r) + offset = _offset_for_steprangelen(astart, astop, len) + nb = ndigits(max(offset-1, len-offset), base=2, pad=0) + T = promote_type(typeof(start/unit(start)), typeof(step/unit(step))) + unitless_range = Base.steprangelen_hp(T, ustrip(au_to, r[offset]), ustrip(au_to, step), nb, len, offset) + return unitless_range * unit(start) + else + return StepRange(start, step, x(r.stop)) + end +end broadcasted(::DefaultArrayStyle{1}, x::BCAST_PROPAGATE_CALLS, r::LinRange) = LinRange(x(r.start), x(r.stop), r.len) broadcasted(::DefaultArrayStyle{1}, ::typeof(|>), r::AbstractRange, x::Ref{<:BCAST_PROPAGATE_CALLS}) = broadcasted(DefaultArrayStyle{1}(), x[], r) +function _offset_for_steprangelen(start, stop, len) + if iszero(start) + return oneunit(len) + elseif iszero(stop) + return len + elseif signbit(start) == signbit(stop) + return abs(start) < abs(stop) ? oneunit(len) : len + else + fstart = Float64(start) + fstop = Float64(stop) + return round(typeof(len), (fstop-len*fstart)/(fstop-fstart)) + end +end + +broadcasted(::DefaultArrayStyle{1}, ::typeof(ustrip), r::StepRangeLen) = + StepRangeLen{typeof(ustrip(zero(eltype(r))))}(ustrip(unit(eltype(r)), r.ref), ustrip(unit(eltype(r)), r.step), r.len, r.offset) +broadcasted(::DefaultArrayStyle{1}, ::typeof(ustrip), r::StepRange) = + ustrip(unit(eltype(r)), r.start):ustrip(unit(eltype(r)), r.step):ustrip(unit(eltype(r)), r.stop) +broadcasted(::DefaultArrayStyle{1}, ::typeof(ustrip), r::LinRange) = + LinRange(ustrip(unit(eltype(r)), r.start), ustrip(unit(eltype(r)), r.stop), r.len) +broadcasted(::DefaultArrayStyle{1}, ::typeof(|>), r::AbstractRange, ::Ref{typeof(ustrip)}) = + broadcasted(DefaultArrayStyle{1}(), ustrip, r) + # for ambiguity resolution broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::StepRangeLen{T}, x::AbstractQuantity) where T = broadcasted(DefaultArrayStyle{1}(), *, r, ustrip(x)) * unit(x) diff --git a/test/runtests.jl b/test/runtests.jl index 075c7d7d..b7a2e39c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,7 +11,7 @@ import Unitful: Ra, °F, °C, K, rad, mrad, °, ms, s, minute, hr, d, yr, Hz, - J, A, N, mol, V, + J, A, N, mol, V, mJ, eV, mW, W, dB, dB_rp, dB_p, dBm, dBV, dBSPL, Decibel, Np, Np_rp, Np_p, Neper, @@ -1325,14 +1325,61 @@ end @test @inferred((1:2:5) .* cm .|> mm) === 10mm:20mm:50mm @test mm.((1:2:5) .* cm) === 10mm:20mm:50mm + @test @inferred(StepRange(1cm,1mm,2cm) .|> km) === (1//100_000)km:(1//1_000_000)km:(2//100_000)km + @test @inferred((1:2:5) .* km .|> upreferred) === 1000m:2000m:5000m @test @inferred((1:2:5)km .|> upreferred) === 1000m:2000m:5000m @test @inferred((1:2:5) .|> upreferred) === 1:2:5 @test @inferred((1.0:2.0:5.0) .* km .|> upreferred) === 1000.0m:2000.0m:5000.0m @test @inferred((1.0:2.0:5.0)km .|> upreferred) === 1000.0m:2000.0m:5000.0m @test @inferred((1.0:2.0:5.0) .|> upreferred) === 1.0:2.0:5.0 + @test @inferred(StepRange(1cm,1mm,2cm) .|> upreferred) === (1//100)m:(1//1000)m:(2//100)m + + # float conversion, dimensionful + for r = [1eV:1eV:5eV, 1eV:1eV:5_000_000eV, 5_000_000eV:-1eV:-1eV, -123_456_789eV:2eV:987_654_321eV, (-11//12)eV:(1//3)eV:(11//4)eV] + for f = (mJ, upreferred) + rf = @inferred(r .|> f) + test_indices = length(r) ≤ 10_000 ? eachindex(r) : rand(eachindex(r), 10_000) + @test eltype(rf) === typeof(f(zero(eltype(r)))) + @test all(≈(rf[i], f(r[i]); rtol=eps()) for i = test_indices) + end + end + + # float conversion from unitless + r = 1:1:360 + rf = °.(r) + @test all(≈(rf[i], °(r[i]); rtol=eps()) for i = eachindex(r)) + + # float conversion to unitless + r = (1:1:360)° + for f = (mrad, NoUnits, upreferred) + rf = f.(r) + @test eltype(rf) === typeof(f(zero(eltype(r)))) + @test all(≈(rf[i], f(r[i]); rtol=eps()) for i = eachindex(r)) + end + + # exact conversion from and to unitless + @test rad.(1:1:360) === (1:1:360)rad + @test mrad.(1:1:360) === (1_000:1_000:360_000)mrad + @test upreferred.(1:1:360) === 1:1:360 + @test NoUnits.((1:1:360)rad) === 1:1:360 + @test upreferred.((1:1:360)rad) === 1:1:360 + @test NoUnits.((1:2:5)mrad) === 1//1000:1//500:1//200 + @test upreferred.((1:2:5)mrad) === 1//1000:1//500:1//200 + @test @inferred((1:2:5) .* cm .|> mm .|> ustrip) === 10:20:50 @test @inferred((1f0:2f0:5f0) .* cm .|> mm .|> ustrip) === 10f0:20f0:50f0 + @test @inferred(StepRange{typeof(1m),typeof(1cm)}(1m,1cm,2m) .|> ustrip) === 1:1//100:2 + @test @inferred(StepRangeLen{typeof(1f0m)}(1.0m, 1.0cm, 101) .|> ustrip) === StepRangeLen{Float32}(1.0, 0.01, 101) + @test @inferred(StepRangeLen{typeof(1.0m)}(Base.TwicePrecision(1.0m), Base.TwicePrecision(1.0cm), 101) .|> ustrip) === StepRangeLen{Float64}(Base.TwicePrecision(1.0), Base.TwicePrecision(0.01), 101) + @test @inferred((1:0.1:1.0) .|> ustrip) == 1:0.1:1.0 + @test @inferred((1m:0.1m:1.0m) .|> ustrip) == 1:0.1:1.0 + @test @inferred(StepRange{typeof(0m),typeof(1cm)}(1m,1cm,2m) .|> ustrip) === 1:1//100:2 + @test @inferred(StepRangeLen{typeof(1f0m)}(1.0m, 1.0cm, 101) .|> ustrip) === StepRangeLen{Float32}(1.0, 0.01, 101) + @test @inferred(StepRangeLen{typeof(1.0m)}(Base.TwicePrecision(1.0m), Base.TwicePrecision(1.0cm), 101) .|> ustrip) === StepRangeLen{Float64}(Base.TwicePrecision(1.0), Base.TwicePrecision(0.01), 101) + @test @inferred(StepRangeLen{typeof(1.0mm)}(Base.TwicePrecision(1.0m), Base.TwicePrecision(1.0cm), 101) .|> ustrip) === 1000.0:10.0:2000.0 + @test ustrip.(1:0.1:1.0) == 1:0.1:1.0 + @test ustrip.(1m:0.1m:1.0m) == 1:0.1:1.0 end @testset ">> quantities and non-quantities" begin @test range(1, step=1m/mm, length=5) == 1:1000:4001 @@ -1400,10 +1447,6 @@ end @test_throws ArgumentError range(step=1m, length=5) end end - @testset ">> broadcast ustrip" begin - @test ustrip.(1:0.1:1.0) == 1:0.1:1.0 - @test ustrip.(1m:0.1m:1.0m) == 1:0.1:1.0 - end end @testset "> Arrays" begin @testset ">> Array multiplication" begin