Skip to content

Commit

Permalink
make Julia global RNG be an instance of MersenneTwister
Browse files Browse the repository at this point in the history
The state of the global RNG was previously handled by libdSFMT.
This commit allows the global RNG to benefit speed improvements resulting
from the use of fill_array_* functions (cf. previous commit).

However, scalar calls to rand() (i.e. producing one value instead of
filling an array) using the global RNG are now slower, as well as those
filling non-Float64 arrays (less than twice as slow though).
  • Loading branch information
rfourquet committed Oct 29, 2014
1 parent 0f2d3ea commit d005415
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions base/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,10 @@ __init__() = srand()
## srand()

function srand(seed::Vector{Uint32})
global RANDOM_SEED = seed
dsfmt_gv_init_by_array(seed)
GLOBAL_RNG.seed = seed
dsfmt_init_by_array(GLOBAL_RNG.state, seed)
GLOBAL_RNG.idx = length(GLOBAL_RNG.vals)
return GLOBAL_RNG
end
srand(n::Integer) = srand(make_seed(n))

Expand All @@ -106,10 +108,15 @@ function srand(filename::String, n::Integer)
end
srand(filename::String) = srand(filename, 4)

## Global RNG

const GLOBAL_RNG = MersenneTwister()
globalRNG() = GLOBAL_RNG

## random floating point values

rand(::Type{Float64}) = dsfmt_gv_genrand_close_open()
rand() = dsfmt_gv_genrand_close_open()
rand() = rand(GLOBAL_RNG)
rand(::Type{Float64}) = rand()

rand(::Type{Float32}) = float32(rand())
rand(::Type{Float16}) = float16(rand())
Expand All @@ -118,13 +125,13 @@ rand{T<:Real}(::Type{Complex{T}}) = complex(rand(T),rand(T))

## random integers

dsfmt_randui32() = dsfmt_gv_genrand_uint32()
dsfmt_randui64() = uint64(dsfmt_randui32()) | (uint64(dsfmt_randui32())<<32)
# this is similar to `dsfmt_genrand_uint32` from dSFMT.h:
dsfmt_randui32(r::MersenneTwister) = reinterpret(Uint64, rand(r)) % Uint32

rand(::Type{Uint8}) = rand(Uint32) % Uint8
rand(::Type{Uint16}) = rand(Uint32) % Uint16
rand(::Type{Uint32}) = dsfmt_randui32()
rand(::Type{Uint64}) = dsfmt_randui64()
rand(::Type{Uint32}) = dsfmt_randui32(GLOBAL_RNG)
rand(::Type{Uint64}) = uint64(rand(Uint32)) <<32 | rand(Uint32)
rand(::Type{Uint128}) = uint128(rand(Uint64))<<64 | rand(Uint64)

rand(::Type{Int8}) = rand(Uint32) % Int8
Expand Down Expand Up @@ -194,6 +201,9 @@ function rand!(r::MersenneTwister, A::Array{Float64})
A
end

rand!(A::AbstractArray{Float64}) = rand!(GLOBAL_RNG, A)
rand!(A::Array{Float64}) = rand!(GLOBAL_RNG, A)

rand(T::Type, dims::Dims) = rand!(Array(T, dims))
rand{T<:Number}(::Type{T}) = error("no random number generator for type $T; try a more specific type")
rand{T<:Number}(::Type{T}, dims::Int...) = rand(T, dims)
Expand Down Expand Up @@ -293,7 +303,7 @@ rand!(B::BitArray) = Base.bitarray_rand_fill!(B)
randbool(dims::Dims) = rand!(BitArray(dims))
randbool(dims::Int...) = rand!(BitArray(dims))

randbool() = ((dsfmt_randui32() & 1) == 1)
randbool() = ((rand(Uint32) & 1) == 1)
rand(::Type{Bool}) = randbool()

## randn() - Normally distributed random numbers using Ziggurat algorithm
Expand Down Expand Up @@ -789,8 +799,7 @@ ziggurat_nor_r = 3.6541528853610087963519472518
ziggurat_nor_inv_r = inv(ziggurat_nor_r)
ziggurat_exp_r = 7.6971174701310497140446280481

randi() = reinterpret(Uint64,dsfmt_gv_genrand_close1_open2()) & 0x000fffffffffffff
@inline randi(rng::MersenneTwister) = reinterpret(Uint64, rand_close1_open2(rng)) & 0x000fffffffffffff
@inline randi(rng::MersenneTwister=GLOBAL_RNG) = reinterpret(Uint64, rand_close1_open2(rng)) & 0x000fffffffffffff
for (lhs, rhs) in (([], []),
([:(rng::MersenneTwister)], [:rng]))
@eval begin
Expand Down

0 comments on commit d005415

Please sign in to comment.