Skip to content

Commit

Permalink
Merge pull request #16919 from JuliaLang/rf/MT-copy
Browse files Browse the repository at this point in the history
implement copy and == for MersenneTwister (fix #15698)
  • Loading branch information
JeffBezanson authored Jun 17, 2016
2 parents 76954a2 + 3207717 commit 2faed27
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 9 deletions.
12 changes: 10 additions & 2 deletions base/dSFMT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

module dSFMT

import Base: copy, copy!, ==

export DSFMT_state, dsfmt_get_min_array_size, dsfmt_get_idstring,
dsfmt_init_gen_rand, dsfmt_init_by_array, dsfmt_gv_init_by_array,
dsfmt_fill_array_close_open!, dsfmt_fill_array_close1_open2!,
Expand All @@ -21,10 +23,16 @@ const JPOLY1e21 = "e172e20c5d2de26b567c0cace9e7c6cc4407bd5ffcd22ca59d37b73d54fd

type DSFMT_state
val::Vector{Int32}
DSFMT_state() = new(Array{Int32}(JN32))
DSFMT_state(val::Vector{Int32}) = new(val)

DSFMT_state(val::Vector{Int32} = zeros(Int32, JN32)) =
new(length(val) == JN32 ? val : throw(DomainError()))
end

copy!(dst::DSFMT_state, src::DSFMT_state) = (copy!(dst.val, src.val); dst)
copy(src::DSFMT_state) = DSFMT_state(copy(src.val))

==(s1::DSFMT_state, s2::DSFMT_state) = s1.val == s2.val

function dsfmt_get_idstring()
idstring = ccall((:dsfmt_get_idstring,:libdSFMT),
Ptr{UInt8},
Expand Down
35 changes: 28 additions & 7 deletions base/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module Random

using Base.dSFMT
using Base.GMP: GMP_VERSION, Limb
import Base.copymutable
import Base: copymutable, copy, copy!, ==

export srand,
rand, rand!,
Expand Down Expand Up @@ -64,16 +64,37 @@ rand(rng::RandomDevice, ::Type{CloseOpen}) = rand(rng, Close1Open2) - 1.0
const MTCacheLength = dsfmt_get_min_array_size()

type MersenneTwister <: AbstractRNG
seed::Vector{UInt32}
state::DSFMT_state
vals::Vector{Float64}
idx::Int
seed::Vector{UInt32}

MersenneTwister(state::DSFMT_state, seed) = new(state, Array{Float64}(MTCacheLength), MTCacheLength, seed)
MersenneTwister(seed) = srand(new(DSFMT_state(), Array{Float64}(MTCacheLength)), seed)
MersenneTwister() = MersenneTwister(0)
function MersenneTwister(seed, state, vals, idx)
length(vals) == MTCacheLength && 0 <= idx <= MTCacheLength || throw(DomainError())
new(seed, state, vals, idx)
end
end

MersenneTwister(seed::Vector{UInt32}, state::DSFMT_state) =
MersenneTwister(seed, state, zeros(Float64, MTCacheLength), MTCacheLength)

MersenneTwister(seed=0) = srand(MersenneTwister(Vector{UInt32}(), DSFMT_state()), seed)

function copy!(dst::MersenneTwister, src::MersenneTwister)
copy!(resize!(dst.seed, length(src.seed)), src.seed)
copy!(dst.state, src.state)
copy!(dst.vals, src.vals)
dst.idx = src.idx
dst
end

copy(src::MersenneTwister) =
MersenneTwister(copy(src.seed), copy(src.state), copy(src.vals), src.idx)

==(r1::MersenneTwister, r2::MersenneTwister) =
r1.seed == r2.seed && r1.state == r2.state && isequal(r1.vals, r2.vals) && r1.idx == r2.idx


## Low level API for MersenneTwister

@inline mt_avail(r::MersenneTwister) = MTCacheLength - r.idx
Expand Down Expand Up @@ -105,7 +126,7 @@ end
@inline rand_ui2x52_raw(r::MersenneTwister) = rand_ui52_raw(r) % UInt128 << 64 | rand_ui52_raw(r)

function srand(r::MersenneTwister, seed::Vector{UInt32})
r.seed = seed
copy!(resize!(r.seed, length(seed)), seed)
dsfmt_init_by_array(r.state, r.seed)
mt_setempty!(r)
return r
Expand All @@ -117,7 +138,7 @@ function randjump(mt::MersenneTwister, jumps::Integer, jumppoly::AbstractString)
push!(mts, mt)
for i in 1:jumps-1
cmt = mts[end]
push!(mts, MersenneTwister(dSFMT.dsfmt_jump(cmt.state, jumppoly), cmt.seed))
push!(mts, MersenneTwister(cmt.seed, dSFMT.dsfmt_jump(cmt.state, jumppoly)))
end
return mts
end
Expand Down
30 changes: 30 additions & 0 deletions test/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -414,3 +414,33 @@ end
# test that the following is not an error (#16925)
srand(typemax(UInt))
srand(typemax(UInt128))

# copy and ==
let seed = rand(UInt32, 10)
r = MersenneTwister(seed)
@test r == MersenneTwister(seed) # r.vals should be all zeros
s = copy(r)
@test s == r && s !== r
skip, len = rand(0:2000, 2)
for j=1:skip
rand(r)
rand(s)
end
@test rand(r, len) == rand(s, len)
@test s == r
end

# MersenneTwister initialization with invalid values
@test_throws DomainError Base.dSFMT.DSFMT_state(zeros(Int32, rand(0:Base.dSFMT.JN32-1)))
@test_throws DomainError MersenneTwister(zeros(UInt32, 1), Base.dSFMT.DSFMT_state(),
zeros(Float64, 10), 0)
@test_throws DomainError MersenneTwister(zeros(UInt32, 1), Base.dSFMT.DSFMT_state(),
zeros(Float64, Base.Random.MTCacheLength), -1)

# seed is private to MersenneTwister
let seed = rand(UInt32, 10)
r = MersenneTwister(seed)
@test r.seed == seed && r.seed !== seed
resize!(seed, 4)
@test r.seed != seed
end

0 comments on commit 2faed27

Please sign in to comment.