diff --git a/src/samplers/ipmcmc.jl b/src/samplers/ipmcmc.jl index e71cc96ff..78f9dae75 100644 --- a/src/samplers/ipmcmc.jl +++ b/src/samplers/ipmcmc.jl @@ -67,10 +67,23 @@ step(model::Function, spl::Sampler{IPMCMC}, VarInfos::Array{VarInfo}, is_first:: log_zs = zeros(spl.alg.n_nodes) # Run SMC & CSMC nodes - for j in 1:spl.alg.n_nodes - VarInfos[j].num_produce = 0 - VarInfos[j] = step(model, spl.info[:samplers][j], VarInfos[j]) - log_zs[j] = spl.info[:samplers][j].info[:logevidence][end] + if nprocs() > 1 + tmp_ = @parallel (vcat) for j in 1:spl.alg.n_nodes + VarInfos[j].num_produce = 0 + step(model, spl.info[:samplers][j], VarInfos[j]) + end + + for j in 1:spl.alg.n_nodes + VarInfos[j] = tmp_[j] + log_zs[j] = getlogp(tmp_[j]) + end + tmp_ = nothing + else + for j in 1:spl.alg.n_nodes + VarInfos[j].num_produce = 0 + VarInfos[j] = step(model, spl.info[:samplers][j], VarInfos[j]) + log_zs[j] = spl.info[:samplers][j].info[:logevidence][end] + end end # Resampling of CSMC nodes indices