-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathpgas.jl
121 lines (98 loc) · 3.5 KB
/
pgas.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""
previous_state(trace::SSMTrace)
Return `Xₜ₋₁` or `nothing` from `model`
"""
function previous_state(trace::SSMTrace)
return trace.model.X[current_step(trace) - 1]
end
"""
    past_idx(trace::SSMTrace)

Return the range of steps strictly before the current step of `trace`, i.e.
`1:(current_step(trace) - 1)`; this range is empty at step 1.
"""
past_idx(trace::SSMTrace) = 1:(current_step(trace) - 1)
"""
current_step(model::AbstractStateSpaceModel)
Return current model step
"""
current_step(trace::SSMTrace) = trace.rng.count
"""
transition_logweight(model::AbstractStateSpaceModel, x)
Get the log weight of the transition from previous state of `model` to `x`
"""
function transition_logweight(particle::SSMTrace, x)
score = SSMProblems.transition_logdensity(
particle.model,
particle.model.X[current_step(particle) - 2],
x,
current_step(particle) - 1,
)
return score
end
"""
get_ancestor_logweights(pc::ParticleContainer{F,R}, x) where {F<:SSMTrace,R}
Get the ancestor log weights for each particle in `pc`
"""
function get_ancestor_logweights(pc::ParticleContainer{<:SSMTrace}, x, weights)
nparticles = length(pc.vals)
logweights = map(1:nparticles) do i
transition_logweight(pc.vals[i], x) + weights[i]
end
return logweights
end
"""
advance!(particle::SSMTrace, isref::Bool=false)
Return the log-probability of the transition nothing if done
"""
function advance!(particle::SSMTrace, isref::Bool=false)
isref ? load_state!(particle.rng) : save_state!(particle.rng)
model = particle.model
running_step = current_step(particle)
isdone(model, running_step) && return nothing
if !isref
if running_step == 1
new_state = SSMProblems.transition!!(particle.rng, model)
else
current_state = model.X[running_step - 1]
new_state = SSMProblems.transition!!(
particle.rng, model, current_state, running_step
)
end
else
new_state = model.X[running_step] # We need the current state from the reference particle
end
score = SSMProblems.emission_logdensity(model, new_state, running_step)
# accept transition
!isref && push!(model.X, new_state)
inc_counter!(particle.rng) # Increase rng counter, we use it as the model `time` index instead of handling two distinct counters
return score
end
"""
    truncate!(particle::SSMTrace)

Keep only the entries `past_idx(particle)` of `particle.model.X` and `particle.rng.keys`,
discarding everything from the current step onward. Both fields are replaced by sliced
copies. Return `particle.model`.
"""
function truncate!(particle::SSMTrace)
    keep = past_idx(particle)
    trajectory = particle.model
    trajectory.X = trajectory.X[keep]
    particle.rng.keys = particle.rng.keys[keep]
    return trajectory
end
"""
    fork(particle::SSMTrace, isref::Bool)

Return a new `Trace` built from deep copies of `particle.model` and `particle.rng`.
When `isref` is true, the copy is passed through `truncate!`, which drops the stored
reference trajectory from the current step onward.
"""
function fork(particle::SSMTrace, isref::Bool)
    child = Trace(deepcopy(particle.model), deepcopy(particle.rng))
    if isref
        truncate!(child)  # forget the rest of the reference trajectory
    end
    return child
end
"""
    forkr(particle::SSMTrace)

Reset the rng counter to 1 and return a new `Trace` built from deep copies of
`particle.model` and `particle.rng`. `gen_refseed!` is then called on the copy
(presumably to draw a fresh reference seed; confirm against that helper).

NOTE(review): `set_counter!` is applied to `particle.rng` before copying, so the counter
of the original `particle` is reset as well, not only the copy's.
"""
function forkr(particle::SSMTrace)
    rng = particle.rng
    Random123.set_counter!(rng, 1)
    fresh = Trace(deepcopy(particle.model), deepcopy(rng))
    gen_refseed!(fresh)
    return fresh
end
"""
    update_ref!(ref::SSMTrace, pc::ParticleContainer{<:SSMTrace}, sampler::PGAS)

Resample the ancestor of the reference trajectory `ref` from the particles in `pc`, and
copy that ancestor's past states and rng keys into `ref` in place.

Returns `nothing` without doing anything when `current_step(ref) <= 2` or when
`ref.model` is done. Otherwise:

1. Ancestor `i` is drawn from `pc.rng` with probability proportional to
   `exp(pc.logWs[i] + transition_logweight(pc.vals[i], x))`, where
   `x = ref.model.X[current_step(ref) - 1]`.
2. Entries `past_idx(ref)` of `ref.model.X` and `ref.rng.keys` are overwritten with the
   ancestor's.
3. The copied keys are returned.

`sampler` is only used for dispatch.
"""
function update_ref!(ref::SSMTrace, pc::ParticleContainer{<:SSMTrace}, sampler::PGAS)
    t = current_step(ref)
    # Steps start at 1, so at the beginning of step 3 there is first something to resample.
    if t <= 2 || isdone(ref.model, t)
        return nothing
    end

    newest_ref_state = ref.model.X[t - 1]
    logws = get_ancestor_logweights(pc, newest_ref_state, pc.logWs)
    probs = StatsFuns.softmax(logws)
    ancestor = pc.vals[rand(pc.rng, Distributions.Categorical(probs))]

    window = past_idx(ref)
    ref.model.X[window] = ancestor.model.X[window]
    copied_keys = ancestor.rng.keys[window]
    ref.rng.keys[window] = copied_keys
    return copied_keys
end