Skip to content

Commit

Permalink
make Julia global RNG be an instance of MersenneTwister
Browse files Browse the repository at this point in the history
The state of the global RNG was previously handled by libdSFMT.
This commit allows the global RNG to benefit speed improvements resulting
from the use of fill_array_* functions (cf. previous commit).

However, individual calls to rand() (i.e. producing one value instead of
filling an array) using the global RNG are now slower, as well as those filling
non-Float64 arrays.
  • Loading branch information
rfourquet committed Oct 28, 2014
1 parent 727c872 commit 62e8a9c
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions base/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,10 @@ __init__() = srand()
## srand()

function srand(seed::Vector{Uint32})
global RANDOM_SEED = seed
dsfmt_gv_init_by_array(seed)
GLOBAL_RNG.seed = seed
dsfmt_init_by_array(GLOBAL_RNG.state, seed)
GLOBAL_RNG.idx = length(GLOBAL_RNG.vals)
return GLOBAL_RNG
end
srand(n::Integer) = srand(make_seed(n))

Expand All @@ -94,10 +96,15 @@ function srand(filename::String, n::Integer)
end
srand(filename::String) = srand(filename, 4)

## Global RNG

GLOBAL_RNG = MersenneTwister()
globalRNG() = GLOBAL_RNG

## random floating point values

rand(::Type{Float64}) = dsfmt_gv_genrand_close_open()
rand() = dsfmt_gv_genrand_close_open()
rand() = rand(GLOBAL_RNG)
rand(::Type{Float64}) = rand()

rand(::Type{Float32}) = float32(rand())
rand(::Type{Float16}) = float16(rand())
Expand All @@ -112,13 +119,13 @@ rand{T<:Real}(::Type{Complex{T}}) = complex(rand(T),rand(T))

## random integers

dsfmt_randui32() = dsfmt_gv_genrand_uint32()
dsfmt_randui64() = uint64(dsfmt_randui32()) | (uint64(dsfmt_randui32())<<32)
# this is similar to `dsfmt_genrand_uint32` from dSFMT.h:
dsfmt_randui32(r::MersenneTwister) = reinterpret(Uint64, rand(r)) % Uint32

rand(::Type{Uint8}) = rand(Uint32) % Uint8
rand(::Type{Uint16}) = rand(Uint32) % Uint16
rand(::Type{Uint32}) = dsfmt_randui32()
rand(::Type{Uint64}) = dsfmt_randui64()
rand(::Type{Uint32}) = dsfmt_randui32(GLOBAL_RNG)
rand(::Type{Uint64}) = uint64(rand(Uint32)) <<32 | rand(Uint32)
rand(::Type{Uint128}) = uint128(rand(Uint64))<<64 | rand(Uint64)

rand(::Type{Int8}) = rand(Uint32) % Int8
Expand Down Expand Up @@ -188,6 +195,9 @@ function rand!(r::MersenneTwister, A::Array{Float64})
A
end

rand!(A::AbstractArray{Float64}) = rand!(GLOBAL_RNG, A)
rand!(A::Array{Float64}) = rand!(GLOBAL_RNG, A)

rand(T::Type, dims::Dims) = rand!(Array(T, dims))
rand{T<:Number}(::Type{T}) = error("no random number generator for type $T; try a more specific type")
rand{T<:Number}(::Type{T}, dims::Int...) = rand(T, dims)
Expand Down Expand Up @@ -783,8 +793,7 @@ ziggurat_nor_r = 3.6541528853610087963519472518
ziggurat_nor_inv_r = inv(ziggurat_nor_r)
ziggurat_exp_r = 7.6971174701310497140446280481

randi() = reinterpret(Uint64,dsfmt_gv_genrand_close1_open2()) & 0x000fffffffffffff
@inline randi(rng::MersenneTwister) = reinterpret(Uint64, rand_close1_open2(rng)) & 0x000fffffffffffff
@inline randi(rng::MersenneTwister=GLOBAL_RNG) = reinterpret(Uint64, rand_close1_open2(rng)) & 0x000fffffffffffff
for (lhs, rhs) in (([], []),
([:(rng::MersenneTwister)], [:rng]))
@eval begin
Expand Down

0 comments on commit 62e8a9c

Please sign in to comment.