From 2a3afb3f39fb63feb0198fa049bc259295879cdc Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Wed, 23 Mar 2022 13:14:10 -0400 Subject: [PATCH 1/6] fallback randn/randexp for AbstractFloat --- NEWS.md | 2 ++ stdlib/Random/src/normal.jl | 13 +++++++++++++ stdlib/Random/test/runtests.jl | 32 ++++++++++++++++++++++++++------ 3 files changed, 41 insertions(+), 6 deletions(-) diff --git a/NEWS.md b/NEWS.md index 6e0b00c92f041..1d402d97f9781 100644 --- a/NEWS.md +++ b/NEWS.md @@ -65,6 +65,8 @@ Standard library changes #### Random +* `randn` and `randexp` now work for any `AbstractFloat` type defining `rand` ([#44714]). + #### REPL #### SparseArrays diff --git a/stdlib/Random/src/normal.jl b/stdlib/Random/src/normal.jl index d7fe94f58fa57..1096f097f565a 100644 --- a/stdlib/Random/src/normal.jl +++ b/stdlib/Random/src/normal.jl @@ -90,6 +90,16 @@ randn(rng::AbstractRNG, ::Type{Complex{T}}) where {T<:AbstractFloat} = Complex{T}(SQRT_HALF * randn(rng, T), SQRT_HALF * randn(rng, T)) +### fallback randn for float types defining rand: +function randn(rng::AbstractRNG, ::Type{T}) where {T<:AbstractFloat} + # Marsaglia polar variant of Box–Muller transform: + while true + x, y = 2rand(rng, T)-1, 2rand(rng, T)-1 + 0 < (s = x^2 + y^2) < 1 || continue + return x * sqrt(-2log(s)/s) # and/or y * sqrt(...) + end +end + ## randexp """ @@ -137,6 +147,9 @@ end end end +### fallback randexp for float types defining rand: +randexp(rng::AbstractRNG, ::Type{T}) where {T<:AbstractFloat} = + -log(rand(rng, T)) ## arrays & other scalar methods diff --git a/stdlib/Random/test/runtests.jl b/stdlib/Random/test/runtests.jl index c8be4c95cdaf2..6333ad43e1328 100644 --- a/stdlib/Random/test/runtests.jl +++ b/stdlib/Random/test/runtests.jl @@ -307,9 +307,32 @@ let a = [rand(RandomDevice(), UInt128) for i=1:10] @test reduce(|, a)>>>64 != 0 end +# wrapper around Float64 to check fallback random generators +struct FakeFloat64 <: AbstractFloat + x::Float64 +end +Base.rand(rng::AbstractRNG, ::Random.SamplerTrivial{Random.CloseOpen01{FakeFloat64}}) = FakeFloat64(rand(rng)) +for f in (:sqrt, :log, :one, :zero, :abs, :+, :-) + @eval Base.$f(x::FakeFloat64) = FakeFloat64($f(x.x)) +end +for f in (:+, :-, :*, :/) + @eval begin + Base.$f(x::FakeFloat64, y::FakeFloat64) = FakeFloat64($f(x.x,y.x)) + Base.$f(x::FakeFloat64, y::Real) = FakeFloat64($f(x.x,y)) + Base.$f(x::Real, y::FakeFloat64) = FakeFloat64($f(x,y.x)) + end +end +for f in (:<, :<=, :>, :>=, :(==), :(!=)) + @eval begin + Base.$f(x::FakeFloat64, y::FakeFloat64) = $f(x.x,y.x) + Base.$f(x::FakeFloat64, y::Real) = $f(x.x,y) + Base.$f(x::Real, y::FakeFloat64) = $f(x,y.x) + end +end + # test all rand APIs for rng in ([], [MersenneTwister(0)], [RandomDevice()], [Xoshiro()]) - ftypes = [Float16, Float32, Float64] + ftypes = [Float16, Float32, Float64, FakeFloat64, BigFloat] cftypes = [ComplexF16, ComplexF32, ComplexF64, ftypes...] types = [Bool, Char, BigFloat, Base.BitInteger_types..., ftypes...] randset = Set(rand(Int, 20)) @@ -406,15 +429,12 @@ for rng in ([], [MersenneTwister(0)], [RandomDevice()], [Xoshiro()]) rand!(rng..., BitMatrix(undef, 2, 3)) ::BitArray{2} # Test that you cannot call randn or randexp with non-Float types. - for r in [randn, randexp, randn!, randexp!] - local r + for r in [randn, randexp] @test_throws MethodError r(Int) @test_throws MethodError r(Int32) @test_throws MethodError r(Bool) @test_throws MethodError r(String) - @test_throws MethodError r(AbstractFloat) - # TODO(#17627): Consider adding support for randn(BigFloat) and removing this test. - @test_throws MethodError r(BigFloat) + @test_throws ArgumentError r(AbstractFloat) @test_throws MethodError r(Int64, (2,3)) @test_throws MethodError r(String, 1) From d5770b8353f3074191692e4734bae19deb1e19f0 Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Wed, 23 Mar 2022 14:04:19 -0400 Subject: [PATCH 2/6] randexp should never return Inf --- stdlib/Random/src/normal.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stdlib/Random/src/normal.jl b/stdlib/Random/src/normal.jl index 1096f097f565a..e5ce5a72244ad 100644 --- a/stdlib/Random/src/normal.jl +++ b/stdlib/Random/src/normal.jl @@ -149,7 +149,7 @@ end ### fallback randexp for float types defining rand: randexp(rng::AbstractRNG, ::Type{T}) where {T<:AbstractFloat} = - -log(rand(rng, T)) + -log(1-rand(rng, T)) ## arrays & other scalar methods From 557e89d576fe72b9d48a2f0cacb40e7166144b2f Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Wed, 23 Mar 2022 14:10:51 -0400 Subject: [PATCH 3/6] shorten --- stdlib/Random/src/normal.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/stdlib/Random/src/normal.jl b/stdlib/Random/src/normal.jl index e5ce5a72244ad..5a7448ab0b3d9 100644 --- a/stdlib/Random/src/normal.jl +++ b/stdlib/Random/src/normal.jl @@ -95,8 +95,7 @@ function randn(rng::AbstractRNG, ::Type{T}) where {T<:AbstractFloat} # Marsaglia polar variant of Box–Muller transform: while true x, y = 2rand(rng, T)-1, 2rand(rng, T)-1 - 0 < (s = x^2 + y^2) < 1 || continue - return x * sqrt(-2log(s)/s) # and/or y * sqrt(...) + 0 < (s = x^2 + y^2) < 1 && return x * sqrt(-2log(s)/s) end end From aa4d95e23d5696abcfffa245573d5d4fb6483a5e Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Wed, 23 Mar 2022 20:09:25 -0400 Subject: [PATCH 4/6] this should really be a MethodError --- stdlib/Random/src/Random.jl | 6 ++++-- stdlib/Random/test/runtests.jl | 10 +++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/stdlib/Random/src/Random.jl b/stdlib/Random/src/Random.jl index e3cd5f7905787..4eb7a418734c9 100644 --- a/stdlib/Random/src/Random.jl +++ b/stdlib/Random/src/Random.jl @@ -143,8 +143,10 @@ Sampler(rng::AbstractRNG, ::Type{X}, r::Repetition=Val(Inf)) where {X} = typeof_rng(rng::AbstractRNG) = typeof(rng) -Sampler(::Type{<:AbstractRNG}, sp::Sampler, ::Repetition) = - throw(ArgumentError("Sampler for this object is not defined")) +# this method is necessary to prevent rand(rng::AbstractRNG, X) from +# recursively constructing nested Sampler types. +Sampler(T::Type{<:AbstractRNG}, sp::Sampler, r::Repetition) = + throw(MethodError(Sampler, (T, sp, r))) # default shortcut for the general case Sampler(::Type{RNG}, X) where {RNG<:AbstractRNG} = Sampler(RNG, X, Val(Inf)) diff --git a/stdlib/Random/test/runtests.jl b/stdlib/Random/test/runtests.jl index 6333ad43e1328..b91f7dde5b726 100644 --- a/stdlib/Random/test/runtests.jl +++ b/stdlib/Random/test/runtests.jl @@ -47,7 +47,7 @@ let A = zeros(2, 2) 0.9103565379264364 0.17732884646626457] end let A = zeros(2, 2) - @test_throws ArgumentError rand!(MersenneTwister(0), A, 5) + @test_throws MethodError rand!(MersenneTwister(0), A, 5) @test rand(MersenneTwister(0), Int64, 1) == [-3433174948434291912] end let A = zeros(Int64, 2, 2) @@ -434,7 +434,7 @@ for rng in ([], [MersenneTwister(0)], [RandomDevice()], [Xoshiro()]) @test_throws MethodError r(Int32) @test_throws MethodError r(Bool) @test_throws MethodError r(String) - @test_throws ArgumentError r(AbstractFloat) + @test_throws MethodError r(AbstractFloat) @test_throws MethodError r(Int64, (2,3)) @test_throws MethodError r(String, 1) @@ -682,7 +682,7 @@ let b = ['0':'9';'A':'Z';'a':'z'] end # this shouldn't crash (#22403) -@test_throws ArgumentError rand!(Union{UInt,Int}[1, 2, 3]) +@test_throws MethodError rand!(Union{UInt,Int}[1, 2, 3]) @testset "$RNG() & Random.seed!(rng::$RNG) initializes randomly" for RNG in (MersenneTwister, RandomDevice, Xoshiro) m = RNG() @@ -754,8 +754,8 @@ end struct RandomStruct23964 end @testset "error message when rand not defined for a type" begin - @test_throws ArgumentError rand(nothing) - @test_throws ArgumentError rand(RandomStruct23964()) + @test_throws MethodError rand(nothing) + @test_throws MethodError rand(RandomStruct23964()) end @testset "rand(::$(typeof(RNG)), ::UnitRange{$T}" for RNG ∈ (MersenneTwister(rand(UInt128)), RandomDevice(), Xoshiro()), From 8682362fb0b7013201537bbbbd1af14db5a5bd5a Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Thu, 24 Mar 2022 08:13:31 -0400 Subject: [PATCH 5/6] log1p --- stdlib/Random/src/normal.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stdlib/Random/src/normal.jl b/stdlib/Random/src/normal.jl index 5a7448ab0b3d9..9d0f1595f052f 100644 --- a/stdlib/Random/src/normal.jl +++ b/stdlib/Random/src/normal.jl @@ -148,7 +148,7 @@ end ### fallback randexp for float types defining rand: randexp(rng::AbstractRNG, ::Type{T}) where {T<:AbstractFloat} = - -log(1-rand(rng, T)) + -log1p(-rand(rng, T)) ## arrays & other scalar methods From 70cbd69bc8d56666218a11e1a3dd1425b7e99e38 Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Thu, 24 Mar 2022 08:13:55 -0400 Subject: [PATCH 6/6] log1p test --- stdlib/Random/test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stdlib/Random/test/runtests.jl b/stdlib/Random/test/runtests.jl index b91f7dde5b726..113e3d942977d 100644 --- a/stdlib/Random/test/runtests.jl +++ b/stdlib/Random/test/runtests.jl @@ -312,7 +312,7 @@ struct FakeFloat64 <: AbstractFloat x::Float64 end Base.rand(rng::AbstractRNG, ::Random.SamplerTrivial{Random.CloseOpen01{FakeFloat64}}) = FakeFloat64(rand(rng)) -for f in (:sqrt, :log, :one, :zero, :abs, :+, :-) +for f in (:sqrt, :log, :log1p, :one, :zero, :abs, :+, :-) @eval Base.$f(x::FakeFloat64) = FakeFloat64($f(x.x)) end for f in (:+, :-, :*, :/)