diff --git a/base/random.jl b/base/random.jl index 1356f008f3b77f..bdbba15f68a799 100644 --- a/base/random.jl +++ b/base/random.jl @@ -68,8 +68,10 @@ __init__() = srand() ## srand() function srand(seed::Vector{Uint32}) - global RANDOM_SEED = seed - dsfmt_gv_init_by_array(seed) + GLOBAL_RNG.seed = seed + dsfmt_init_by_array(GLOBAL_RNG.state, seed) + GLOBAL_RNG.idx = length(GLOBAL_RNG.vals) + return GLOBAL_RNG end srand(n::Integer) = srand(make_seed(n)) @@ -94,10 +96,15 @@ function srand(filename::String, n::Integer) end srand(filename::String) = srand(filename, 4) +## Global RNG + +GLOBAL_RNG = MersenneTwister() +globalRNG() = GLOBAL_RNG + ## random floating point values -rand(::Type{Float64}) = dsfmt_gv_genrand_close_open() -rand() = dsfmt_gv_genrand_close_open() +rand() = rand(GLOBAL_RNG) +rand(::Type{Float64}) = rand() rand(::Type{Float32}) = float32(rand()) rand(::Type{Float16}) = float16(rand()) @@ -112,13 +119,13 @@ rand{T<:Real}(::Type{Complex{T}}) = complex(rand(T),rand(T)) ## random integers -dsfmt_randui32() = dsfmt_gv_genrand_uint32() -dsfmt_randui64() = uint64(dsfmt_randui32()) | (uint64(dsfmt_randui32())<<32) +# this is similar to `dsfmt_genrand_uint32` from dSFMT.h: +dsfmt_randui32(r::MersenneTwister) = reinterpret(Uint64, rand(r)) % Uint32 rand(::Type{Uint8}) = rand(Uint32) % Uint8 rand(::Type{Uint16}) = rand(Uint32) % Uint16 -rand(::Type{Uint32}) = dsfmt_randui32() -rand(::Type{Uint64}) = dsfmt_randui64() +rand(::Type{Uint32}) = dsfmt_randui32(GLOBAL_RNG) +rand(::Type{Uint64}) = uint64(rand(Uint32)) <<32 | rand(Uint32) rand(::Type{Uint128}) = uint128(rand(Uint64))<<64 | rand(Uint64) rand(::Type{Int8}) = rand(Uint32) % Int8 @@ -188,6 +195,9 @@ function rand!(r::MersenneTwister, A::Array{Float64}) A end +rand!(A::AbstractArray{Float64}) = rand!(GLOBAL_RNG, A) +rand!(A::Array{Float64}) = rand!(GLOBAL_RNG, A) + rand(T::Type, dims::Dims) = rand!(Array(T, dims)) rand{T<:Number}(::Type{T}) = error("no random number generator for type $T; try a more specific type") rand{T<:Number}(::Type{T}, dims::Int...) = rand(T, dims) @@ -783,8 +793,7 @@ ziggurat_nor_r = 3.6541528853610087963519472518 ziggurat_nor_inv_r = inv(ziggurat_nor_r) ziggurat_exp_r = 7.6971174701310497140446280481 -randi() = reinterpret(Uint64,dsfmt_gv_genrand_close1_open2()) & 0x000fffffffffffff -@inline randi(rng::MersenneTwister) = reinterpret(Uint64, rand_close1_open2(rng)) & 0x000fffffffffffff +@inline randi(rng::MersenneTwister=GLOBAL_RNG) = reinterpret(Uint64, rand_close1_open2(rng)) & 0x000fffffffffffff for (lhs, rhs) in (([], []), ([:(rng::MersenneTwister)], [:rng])) @eval begin