Skip to content

Commit

Permalink
Move delete_seeds to ext
Browse files Browse the repository at this point in the history
  • Loading branch information
FredericWantiez committed Jul 25, 2023
1 parent 18463e4 commit 85f9338
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 12 deletions.
13 changes: 12 additions & 1 deletion ext/AdvancedPSLibtaskExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ function AdvancedPS.fork(trace::LibtaskTrace, isref::Bool=false)
newtrace = copy(trace)
update_rng!(newtrace)
isref && AdvancedPS.delete_retained!(newtrace.model.f)
isref && AdvancedPS.delete_seeds!(newtrace)
isref && delete_seeds!(newtrace)

# add backward reference
addreference!(newtrace.model.ctask.task, newtrace)
Expand Down Expand Up @@ -204,4 +204,15 @@ function replay(particle::AdvancedPS.Particle)
return trace
end

"""
delete_seeds!(particle::Particle)
Truncate the seed history from the `particle` rng. When forking the reference Particle
we need to keep the seeds up to the current model iteration but generate new seeds
and random values afterward.
"""
function delete_seeds!(particle::AdvancedPS.Particle)
return particle.rng.keys = particle.rng.keys[1:(particle.rng.count - 1)]

Check warning on line 215 in ext/AdvancedPSLibtaskExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AdvancedPSLibtaskExt.jl#L214-L215

Added lines #L214 - L215 were not covered by tests
end

end
11 changes: 0 additions & 11 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,6 @@ Base.copy(trace::Trace) = Trace(copy(trace.model), deepcopy(trace.rng))

function observe end

"""
delete_seeds!(particle::Particle)
Truncate the seed history from the `particle` rng. When forking the reference Particle
we need to keep the seeds up to the current model iteration but generate new seeds
and random values afterward.
"""
function delete_seeds!(particle::Particle)
return particle.rng.keys = particle.rng.keys[1:(particle.rng.count - 1)]
end

"""
gen_refseed!(particle::Particle)
Expand Down
11 changes: 11 additions & 0 deletions test/pgas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,15 @@

@test vals1 vals2
end

@testset "smc sampler" begin
model = BaseModel(Params(0.9, 0.32, 1))
npart = 10

sampler = AdvancedPS.SMC(npart)
chains = sample(model, sampler)

@test length(chains.trajectories) == npart
@test length(chains.trajectories[1].model.X) == 3
end
end

0 comments on commit 85f9338

Please sign in to comment.