Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dSFMT: use fill_array_* API instead of genrand_* API #8832

Merged
merged 2 commits into from
Oct 30, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions base/dSFMT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ export DSFMT_state, dsfmt_get_min_array_size, dsfmt_get_idstring,
dsfmt_genrand_close1_open2, dsfmt_gv_genrand_close1_open2,
dsfmt_genrand_close_open, dsfmt_gv_genrand_close_open,
dsfmt_genrand_uint32, dsfmt_gv_genrand_uint32,
dsfmt_fill_array_close_open!, dsfmt_fill_array_close1_open2!,
win32_SystemFunction036!

type DSFMT_state
Expand Down Expand Up @@ -95,6 +96,24 @@ function dsfmt_gv_genrand_uint32()
())
end

# precondition for dsfmt_fill_array_*:
# the underlying C array must be 16-byte aligned, which is the case for "Array"
function dsfmt_fill_array_close1_open2!(s::DSFMT_state, A::Array{Float64}, n::Int)
@assert dsfmt_min_array_size <= n <= length(A) && iseven(n)
ccall((:dsfmt_fill_array_close1_open2,:libdSFMT),
Void,
(Ptr{Void}, Ptr{Float64}, Int),
s.val, A, n)
end

function dsfmt_fill_array_close_open!(s::DSFMT_state, A::Array{Float64}, n::Int)
@assert dsfmt_min_array_size <= n <= length(A) && iseven(n)
ccall((:dsfmt_fill_array_close_open,:libdSFMT),
Void,
(Ptr{Void}, Ptr{Float64}, Int),
s.val, A, n)
end

## Windows entropy

@windows_only begin
Expand Down
102 changes: 82 additions & 20 deletions base/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,43 @@ abstract AbstractRNG
type MersenneTwister <: AbstractRNG
state::DSFMT_state
seed::Union(Uint32,Vector{Uint32})
vals::Vector{Float64}
idx::Int

function MersenneTwister(seed::Vector{Uint32})
state = DSFMT_state()
dsfmt_init_by_array(state, seed)
return new(state, seed)
return new(state, seed, Array(Float64, dsfmt_get_min_array_size()), dsfmt_get_min_array_size())
end

MersenneTwister(seed=0) = MersenneTwister(make_seed(seed))
end

## Low level API for MersenneTwister

function gen_rand(r::MersenneTwister)
dsfmt_fill_array_close1_open2!(r.state, r.vals, length(r.vals))
r.idx = 0
end

@inline gen_rand_maybe(r::MersenneTwister) = r.idx == length(r.vals) && gen_rand(r)

# precondition: r.idx < length(r.vals)
@inline rand_close1_open2_inbounds(r::MersenneTwister) = (r.idx += 1; @inbounds return r.vals[r.idx])
@inline rand_inbounds(r::MersenneTwister) = rand_close1_open2_inbounds(r) - 1.0

# produce Float64 values
@inline rand_close1_open2(r::MersenneTwister) = (gen_rand_maybe(r); rand_close1_open2_inbounds(r))
@inline rand_close_open(r::MersenneTwister) = (gen_rand_maybe(r); rand_inbounds(r))

# this is similar to `dsfmt_genrand_uint32` from dSFMT.h:
@inline rand_ui32(r::MersenneTwister) = reinterpret(Uint64, rand_close1_open2(r)) % Uint32


function srand(r::MersenneTwister, seed)
r.seed = seed
dsfmt_init_gen_rand(r.state, seed)
r.idx = length(r.vals)
return r
end

Expand Down Expand Up @@ -60,8 +84,10 @@ __init__() = srand()
## srand()

function srand(seed::Vector{Uint32})
global RANDOM_SEED = seed
dsfmt_gv_init_by_array(seed)
GLOBAL_RNG.seed = seed
dsfmt_init_by_array(GLOBAL_RNG.state, seed)
GLOBAL_RNG.idx = length(GLOBAL_RNG.vals)
return GLOBAL_RNG
end
srand(n::Integer) = srand(make_seed(n))

Expand All @@ -86,28 +112,28 @@ function srand(filename::String, n::Integer)
end
srand(filename::String) = srand(filename, 4)

## Global RNG

const GLOBAL_RNG = MersenneTwister()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we should add a type annotation on GLOBAL_RNG too.

globalRNG() = GLOBAL_RNG

## random floating point values

rand(::Type{Float64}) = dsfmt_gv_genrand_close_open()
rand() = dsfmt_gv_genrand_close_open()
rand(r::MersenneTwister=GLOBAL_RNG) = rand_close_open(r)

rand(::Type{Float64}) = rand()

rand(::Type{Float32}) = float32(rand())
rand(::Type{Float16}) = float16(rand())

rand{T<:Real}(::Type{Complex{T}}) = complex(rand(T),rand(T))


rand(r::MersenneTwister) = dsfmt_genrand_close_open(r.state)

## random integers

dsfmt_randui32() = dsfmt_gv_genrand_uint32()
dsfmt_randui64() = uint64(dsfmt_randui32()) | (uint64(dsfmt_randui32())<<32)

rand(::Type{Uint8}) = rand(Uint32) % Uint8
rand(::Type{Uint16}) = rand(Uint32) % Uint16
rand(::Type{Uint32}) = dsfmt_randui32()
rand(::Type{Uint64}) = dsfmt_randui64()
rand(::Type{Uint32}) = rand_ui32(GLOBAL_RNG)
rand(::Type{Uint64}) = uint64(rand(Uint32)) <<32 | rand(Uint32)
rand(::Type{Uint128}) = uint128(rand(Uint64))<<64 | rand(Uint64)

rand(::Type{Int8}) = rand(Uint32) % Int8
Expand Down Expand Up @@ -142,6 +168,44 @@ function rand!{T}(r::AbstractRNG, A::AbstractArray{T})
A
end

function rand_AbstractArray_Float64!(r::MersenneTwister, A::AbstractArray{Float64})
n = length(A)
# what follows is equivalent to this simple loop but more efficient:
# for i=1:n
# @inbounds A[i] = rand(r)
# end
m = 0
while m < n
s = length(r.vals) - r.idx
if s == 0
gen_rand(r)
s = length(r.vals)
end
m2 = min(n, m+s)
for i=m+1:m2
@inbounds A[i] = rand_inbounds(r)
end
m = m2
end
A
end

rand!(r::MersenneTwister, A::AbstractArray{Float64}) = rand_AbstractArray_Float64!(r, A)

function rand!(r::MersenneTwister, A::Array{Float64})
n = length(A)
if n < dsfmt_get_min_array_size()
rand_AbstractArray_Float64!(r, A)
else
dsfmt_fill_array_close_open!(r.state, A, 2*(n ÷ 2))
isodd(n) && (A[n] = rand(r))
end
A
end

rand!(A::AbstractArray{Float64}) = rand!(GLOBAL_RNG, A)
rand!(A::Array{Float64}) = rand!(GLOBAL_RNG, A)

rand(T::Type, dims::Dims) = rand!(Array(T, dims))
rand{T<:Number}(::Type{T}) = error("no random number generator for type $T; try a more specific type")
rand{T<:Number}(::Type{T}, dims::Int...) = rand(T, dims)
Expand Down Expand Up @@ -241,7 +305,7 @@ rand!(B::BitArray) = Base.bitarray_rand_fill!(B)
randbool(dims::Dims) = rand!(BitArray(dims))
randbool(dims::Int...) = rand!(BitArray(dims))

randbool() = ((dsfmt_randui32() & 1) == 1)
randbool() = ((rand(Uint32) & 1) == 1)
rand(::Type{Bool}) = randbool()

## randn() - Normally distributed random numbers using Ziggurat algorithm
Expand Down Expand Up @@ -737,11 +801,9 @@ ziggurat_nor_r = 3.6541528853610087963519472518
ziggurat_nor_inv_r = inv(ziggurat_nor_r)
ziggurat_exp_r = 7.6971174701310497140446280481

rand(state::DSFMT_state) = dsfmt_genrand_close_open(state)
randi() = reinterpret(Uint64,dsfmt_gv_genrand_close1_open2()) & 0x000fffffffffffff
randi(state::DSFMT_state) = reinterpret(Uint64,dsfmt_genrand_close1_open2(state)) & 0x000fffffffffffff
@inline randi(rng::MersenneTwister=GLOBAL_RNG) = reinterpret(Uint64, rand_close1_open2(rng)) & 0x000fffffffffffff
for (lhs, rhs) in (([], []),
([:(state::DSFMT_state)], [:state]))
([:(rng::MersenneTwister)], [:rng]))
@eval begin
function randmtzig_randn($(lhs...))
@inbounds begin
Expand Down Expand Up @@ -787,9 +849,9 @@ for (lhs, rhs) in (([], []),
end

randn() = randmtzig_randn()
randn(rng::MersenneTwister) = randmtzig_randn(rng.state)
randn(rng::MersenneTwister) = randmtzig_randn(rng)
randn!(A::Array{Float64}) = (for i = 1:length(A);A[i] = randmtzig_randn();end;A)
randn!(rng::MersenneTwister, A::Array{Float64}) = (for i = 1:length(A);A[i] = randmtzig_randn(rng.state);end;A)
randn!(rng::MersenneTwister, A::Array{Float64}) = (for i = 1:length(A);A[i] = randmtzig_randn(rng);end;A)
randn(dims::Dims) = randn!(Array(Float64, dims))
randn(dims::Int...) = randn!(Array(Float64, dims...))
randn(rng::MersenneTwister, dims::Dims) = randn!(rng, Array(Float64, dims))
Expand Down