From ea5c9cba8c97e1d221d4b48afe1d2d965c48ee64 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Fri, 14 Apr 2023 13:34:10 +0200 Subject: [PATCH] Make `normalize` work for `Number`s (#49342) --- stdlib/LinearAlgebra/src/generic.jl | 7 ++--- stdlib/LinearAlgebra/test/generic.jl | 28 +++-------------- test/testhelpers/DualNumbers.jl | 46 ++++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 29 deletions(-) create mode 100644 test/testhelpers/DualNumbers.jl diff --git a/stdlib/LinearAlgebra/src/generic.jl b/stdlib/LinearAlgebra/src/generic.jl index 0c947936dee6b..c66f59838e8ba 100644 --- a/stdlib/LinearAlgebra/src/generic.jl +++ b/stdlib/LinearAlgebra/src/generic.jl @@ -1804,21 +1804,18 @@ function normalize!(a::AbstractArray, p::Real=2) __normalize!(a, nrm) end -@inline function __normalize!(a::AbstractArray, nrm::Real) +@inline function __normalize!(a::AbstractArray, nrm) # The largest positive floating point number whose inverse is less than infinity δ = inv(prevfloat(typemax(nrm))) - if nrm ≥ δ # Safe to multiply with inverse invnrm = inv(nrm) rmul!(a, invnrm) - else # scale elements to avoid overflow εδ = eps(one(nrm))/δ rmul!(a, εδ) rmul!(a, inv(nrm*εδ)) end - - a + return a end """ diff --git a/stdlib/LinearAlgebra/test/generic.jl b/stdlib/LinearAlgebra/test/generic.jl index 108d3aec8f069..3ebaf38e84945 100644 --- a/stdlib/LinearAlgebra/test/generic.jl +++ b/stdlib/LinearAlgebra/test/generic.jl @@ -12,6 +12,8 @@ using .Main.Quaternions isdefined(Main, :OffsetArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "OffsetArrays.jl")) using .Main.OffsetArrays +isdefined(Main, :DualNumbers) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "DualNumbers.jl")) +using .Main.DualNumbers Random.seed!(123) @@ -78,30 +80,7 @@ n = 5 # should be odd end @testset "det with nonstandard Number type" begin - struct MyDual{T<:Real} <: Real - val::T - eps::T - end - Base.:+(x::MyDual, y::MyDual) = MyDual(x.val + y.val, x.eps + y.eps) - Base.:*(x::MyDual, y::MyDual) = MyDual(x.val * y.val, x.eps * y.val + y.eps * x.val) - Base.:/(x::MyDual, y::MyDual) = x.val / y.val - Base.:(==)(x::MyDual, y::MyDual) = x.val == y.val && x.eps == y.eps - Base.zero(::MyDual{T}) where {T} = MyDual(zero(T), zero(T)) - Base.zero(::Type{MyDual{T}}) where {T} = MyDual(zero(T), zero(T)) - Base.one(::MyDual{T}) where {T} = MyDual(one(T), zero(T)) - Base.one(::Type{MyDual{T}}) where {T} = MyDual(one(T), zero(T)) - # the following line is required for BigFloat, IDK why it doesn't work via - # promote_rule like for all other types - Base.promote_type(::Type{MyDual{BigFloat}}, ::Type{BigFloat}) = MyDual{BigFloat} - Base.promote_rule(::Type{MyDual{T}}, ::Type{S}) where {T,S<:Real} = - MyDual{promote_type(T, S)} - Base.promote_rule(::Type{MyDual{T}}, ::Type{MyDual{S}}) where {T,S} = - MyDual{promote_type(T, S)} - Base.convert(::Type{MyDual{T}}, x::MyDual) where {T} = - MyDual(convert(T, x.val), convert(T, x.eps)) - if elty <: Real - @test det(triu(MyDual.(A, zero(A)))) isa MyDual - end + elty <: Real && @test det(Dual.(triu(A), zero(A))) isa Dual end end @@ -390,6 +369,7 @@ end [1.0 2.0 3.0; 4.0 5.0 6.0], # 2-dim rand(1,2,3), # higher dims rand(1,2,3,4), + Dual.(randn(2,3), randn(2,3)), OffsetArray([-1,0], (-2,)) # no index 1 ) @test normalize(arr) == normalize!(copy(arr)) diff --git a/test/testhelpers/DualNumbers.jl b/test/testhelpers/DualNumbers.jl new file mode 100644 index 0000000000000..9f62e3bf0d429 --- /dev/null +++ b/test/testhelpers/DualNumbers.jl @@ -0,0 +1,46 @@ +# This file is a part of Julia. License is MIT: https://julialang.org/license + +module DualNumbers + +export Dual + +# Dual numbers type with minimal interface +# example of a (real) number type that subtypes Number, but not Real. +# Can be used to test generic linear algebra functions. + +struct Dual{T<:Real} <: Number + val::T + eps::T +end +Base.:+(x::Dual, y::Dual) = Dual(x.val + y.val, x.eps + y.eps) +Base.:-(x::Dual, y::Dual) = Dual(x.val - y.val, x.eps - y.eps) +Base.:*(x::Dual, y::Dual) = Dual(x.val * y.val, x.eps * y.val + y.eps * x.val) +Base.:*(x::Number, y::Dual) = Dual(x*y.val, x*y.eps) +Base.:*(x::Dual, y::Number) = Dual(x.val*y, x.eps*y) +Base.:/(x::Dual, y::Dual) = Dual(x.val / y.val, (x.eps*y.val - x.val*y.eps)/(y.val*y.val)) + +Base.:(==)(x::Dual, y::Dual) = x.val == y.val && x.eps == y.eps + +Base.promote_rule(::Type{Dual{T}}, ::Type{T}) where {T} = Dual{T} +Base.promote_rule(::Type{Dual{T}}, ::Type{S}) where {T,S<:Real} = Dual{promote_type(T, S)} +Base.promote_rule(::Type{Dual{T}}, ::Type{Dual{S}}) where {T,S} = Dual{promote_type(T, S)} + +Base.convert(::Type{Dual{T}}, x::Dual{T}) where {T} = x +Base.convert(::Type{Dual{T}}, x::Dual) where {T} = Dual(convert(T, x.val), convert(T, x.eps)) +Base.convert(::Type{Dual{T}}, x::Real) where {T} = Dual(convert(T, x), zero(T)) + +Base.float(x::Dual) = Dual(float(x.val), float(x.eps)) +# the following two methods are needed for normalize (to check for potential overflow) +Base.typemax(x::Dual) = Dual(typemax(x.val), zero(x.eps)) +Base.prevfloat(x::Dual{<:AbstractFloat}) = prevfloat(x.val) + +Base.abs2(x::Dual) = x*x +Base.abs(x::Dual) = sqrt(abs2(x)) +Base.sqrt(x::Dual) = Dual(sqrt(x.val), x.eps/(2sqrt(x.val))) + +Base.isless(x::Dual, y::Dual) = x.val < y.val +Base.isless(x::Real, y::Dual) = x < y.val +Base.isinf(x::Dual) = isinf(x.val) & isfinite(x.eps) +Base.real(x::Dual) = x # since we curently only consider Dual{<:Real} + +end # module