From 3ca78c1e408da46287989d3a9fee1e1e3666f560 Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Wed, 23 Mar 2022 13:14:10 -0400 Subject: [PATCH] fallback randn/randexp for AbstractFloat --- NEWS.md | 2 ++ stdlib/Random/src/normal.jl | 13 +++++++++++++ stdlib/Random/test/runtests.jl | 32 ++++++++++++++++++++++++++------ 3 files changed, 41 insertions(+), 6 deletions(-) diff --git a/NEWS.md b/NEWS.md index 6e0b00c92f0416..2f8476abb2cc7a 100644 --- a/NEWS.md +++ b/NEWS.md @@ -65,6 +65,8 @@ Standard library changes #### Random +* `randn` and `randexp` now work for any `AbstractFloat` type defining `rand` ([#44713]). + #### REPL #### SparseArrays diff --git a/stdlib/Random/src/normal.jl b/stdlib/Random/src/normal.jl index 6bb4cd2c36ce80..b143924da74e31 100644 --- a/stdlib/Random/src/normal.jl +++ b/stdlib/Random/src/normal.jl @@ -90,6 +90,16 @@ randn(rng::AbstractRNG, ::Type{Complex{T}}) where {T<:AbstractFloat} = Complex{T}(SQRT_HALF * randn(rng, T), SQRT_HALF * randn(rng, T)) +### fallback randn for float types defining rand: +function randn(rng::AbstractRNG, ::Type{T}) where {T<:AbstractFloat} + # Marsaglia polar variant of Box–Muller transform: + while true + x, y = 2rand(rng, T)-1, 2rand(rng, T)-1 + 0 < (s = x^2 + y^2) < 1 || continue + return x * sqrt(-2log(s)/s) # and/or y * sqrt(...) + end +end + ## randexp """ @@ -137,6 +147,9 @@ end end end +### fallback randexp for float types defining rand: +randexp(rng::AbstractRNG, ::Type{T}) where {T<:AbstractFloat} = + -log(rand(rng, T)) ## arrays & other scalar methods diff --git a/stdlib/Random/test/runtests.jl b/stdlib/Random/test/runtests.jl index c8be4c95cdaf2f..6333ad43e1328e 100644 --- a/stdlib/Random/test/runtests.jl +++ b/stdlib/Random/test/runtests.jl @@ -307,9 +307,32 @@ let a = [rand(RandomDevice(), UInt128) for i=1:10] @test reduce(|, a)>>>64 != 0 end +# wrapper around Float64 to check fallback random generators +struct FakeFloat64 <: AbstractFloat + x::Float64 +end +Base.rand(rng::AbstractRNG, ::Random.SamplerTrivial{Random.CloseOpen01{FakeFloat64}}) = FakeFloat64(rand(rng)) +for f in (:sqrt, :log, :one, :zero, :abs, :+, :-) + @eval Base.$f(x::FakeFloat64) = FakeFloat64($f(x.x)) +end +for f in (:+, :-, :*, :/) + @eval begin + Base.$f(x::FakeFloat64, y::FakeFloat64) = FakeFloat64($f(x.x,y.x)) + Base.$f(x::FakeFloat64, y::Real) = FakeFloat64($f(x.x,y)) + Base.$f(x::Real, y::FakeFloat64) = FakeFloat64($f(x,y.x)) + end +end +for f in (:<, :<=, :>, :>=, :(==), :(!=)) + @eval begin + Base.$f(x::FakeFloat64, y::FakeFloat64) = $f(x.x,y.x) + Base.$f(x::FakeFloat64, y::Real) = $f(x.x,y) + Base.$f(x::Real, y::FakeFloat64) = $f(x,y.x) + end +end + # test all rand APIs for rng in ([], [MersenneTwister(0)], [RandomDevice()], [Xoshiro()]) - ftypes = [Float16, Float32, Float64] + ftypes = [Float16, Float32, Float64, FakeFloat64, BigFloat] cftypes = [ComplexF16, ComplexF32, ComplexF64, ftypes...] types = [Bool, Char, BigFloat, Base.BitInteger_types..., ftypes...] randset = Set(rand(Int, 20)) @@ -406,15 +429,12 @@ for rng in ([], [MersenneTwister(0)], [RandomDevice()], [Xoshiro()]) rand!(rng..., BitMatrix(undef, 2, 3)) ::BitArray{2} # Test that you cannot call randn or randexp with non-Float types. - for r in [randn, randexp, randn!, randexp!] - local r + for r in [randn, randexp] @test_throws MethodError r(Int) @test_throws MethodError r(Int32) @test_throws MethodError r(Bool) @test_throws MethodError r(String) - @test_throws MethodError r(AbstractFloat) - # TODO(#17627): Consider adding support for randn(BigFloat) and removing this test. - @test_throws MethodError r(BigFloat) + @test_throws ArgumentError r(AbstractFloat) @test_throws MethodError r(Int64, (2,3)) @test_throws MethodError r(String, 1)