Skip to content

Commit

Permalink
Rel 0.4.0 - Multiple seeds if num_chains > 1
Browse files Browse the repository at this point in the history
  • Loading branch information
goedman committed Jan 30, 2024
1 parent 4b42ea8 commit 9fb7d9e
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 20 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "StanPathfinder"
uuid = "e8ee4b5e-54b2-4408-8575-c3c89e582a15"
authors = ["Rob J Goedman <goedman@mac.com>"]
version = "0.3.0"
version = "0.4.0"

[deps]
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
Expand All @@ -11,6 +11,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
StanBase = "d0ee94f6-a23d-54aa-bbe9-7f572d6da7f5"
StanIO = "a1b0710c-ff81-4c57-8075-167cfc590dd3"
Expand All @@ -23,6 +24,7 @@ DataFrames = "1"
DocStringExtensions = "0.9"
NamedTupleTools = "0.14"
Parameters = "0.12"
Primes = "0.5"
Reexport = "1.2"
StanBase = "4.7"
StanIO = "1"
Expand Down
17 changes: 15 additions & 2 deletions examples/Bernoulli/bernoulli.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ data = Dict("N" => 10, "y" => [0, 1, 0, 1, 0, 0, 0, 0, 0, 1])
tmpdir = joinpath(@__DIR__, "tmp")

sm = PathfinderModel("bernoulli", bernoulli_model)
rc = stan_pathfinder(sm; data, seed=rand(1:200000000, 1)[1], num_chains=2)
rc = stan_pathfinder(sm; data)

if all(success.(rc))

Expand All @@ -39,4 +39,17 @@ if all(success.(rc))
end

sm2 = PathfinderModel("bernoulli2", bernoulli_model, tmpdir)
rc2 = stan_pathfinder(sm2; data, seed=rand(1:200000000, 1)[1], num_chains=2)
rc2 = stan_pathfinder(sm2; data, seed=rand(1:200000000, 2), num_chains=2)

if all(success.(rc2))

str = read(joinpath(sm2.tmpdir, "$(sm2.name)_log_1.log"), String)
findfirst("Path [1]", str)
str = split(str[findfirst("Path [1]", str)[1]:end], "\n")
display(str)

df = read_pathfinder(sm2)
profile_df = create_pathfinder_profile_df(sm2)
display(profile_df)

end
2 changes: 1 addition & 1 deletion src/StanPathfinder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using Reexport

using CSV, DelimitedFiles, Unicode
using NamedTupleTools, Parameters
using DataFrames, Distributed
using DataFrames, Distributed, Primes

using DocStringExtensions: FIELDS, SIGNATURES, TYPEDEF

Expand Down
23 changes: 14 additions & 9 deletions src/stanmodel/PathfinderModel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ mutable struct PathfinderModel <: CmdStanModels
num_elbo_draws::Int;

init::Int;
seed::Int;
seed::Vector{Int};
refresh::Int;
sig_figs::Int;
num_threads::Int;
Expand Down Expand Up @@ -93,10 +93,12 @@ function PathfinderModel(
throw(StanModelError(model, String(take!(error_output))))
end

num_chains = 1

PathfinderModel(name, model,
# Pathfinder default settings
# num_chains
1,
num_chains,
# init_alpha
0.001,
# tol_obj, tol_rel_obj
Expand All @@ -106,16 +108,19 @@ function PathfinderModel(
# tol_param
1e-8,
# history_size, num_psis_draws, num_paths
5, 1000, 4,
5, 2000, 4,
# psis_resample, calculate_lp, save_single_paths
true, true, false,
#max_lbfgs, num_draws, num_elbo_draws
1000, 1000, 25,

# init, seed, refresh, sig_figs, num_threads
2, 1995513073, 100, -1, 1,
1000, 2000, 25,
# init
2,
# seeds
rand(primes(10000001, 20000001), num_chains),
# refresh, sig_figs, num_threads
100, -1, 1,
# save_cmdstan_config
false,
true,

output_base, # Path to output files
tmpdir, # Tmpdir settings
Expand All @@ -140,6 +145,7 @@ function Base.show(io::IO, ::MIME"text/plain", m::PathfinderModel)
println(io, " refresh = ", m.refresh)
println(io, " sig_figs = ", m.sig_figs)
println(io, " num_threads = ", m.num_threads)
println(io, " save_cmdstan_config = ", m.save_cmdstan_config)

println(io, "\nPathfiner section:")
println(io, " init_alpha = ", m.init_alpha)
Expand All @@ -159,7 +165,6 @@ function Base.show(io::IO, ::MIME"text/plain", m::PathfinderModel)
println(io, " num_draws = ", m.num_draws)
println(io, " num_elbo_draws = ", m.num_elbo_draws)


println(io, "\nOther:")
println(io, " output_base = ", m.output_base)
println(io, " tmpdir = ", m.tmpdir)
Expand Down
2 changes: 1 addition & 1 deletion src/stanrun/cmdline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ function cmdline(m::PathfinderModel, id)
cmd = `$cmd init=$(m.init)`
end

cmd = `$cmd random seed=$(m.seed)`
cmd = `$cmd random seed=$(m.seed[id])`

# Output options
cmd = `$cmd output`
Expand Down
28 changes: 22 additions & 6 deletions src/stanrun/stan_run.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ See extended help for other keyword arguments ( `??stan_sample` ).
### Additional configuration keyword arguments
```julia
* `num_chains=1` # Update number of chains.
* `num_chains=4` # Number of chains.
* `init=2` # Bound for initial values.
* `seed=1995513073` # Set seed value.
* `refresh=100` # Strem to output.
* `sig_figs=-1` # Number of significant decimals used.
* `num_threads=1` # Number of threads.
* `init=2` # Bound for initial values.
* `seed=rand(primes(1, 20000001), num_chains)` # Array of seed values.
* `refresh=100` # Stream to output.
* `sig_figs=-1` # Number of significant decimals used.
* `num_threads=1` # Number of threads.
* `init_alpha=0.001`
* `tol_obj=9.99999999e-13`
Expand All @@ -56,6 +56,22 @@ function stan_run(m::PathfinderModel, use_json=true; kwargs...)

handle_keywords!(m, kwargs)

if :num_chains in keys(kwargs)
m.num_chains = kwargs[:num_chains]
m.seed = rand(primes(1, 20000001), m.num_chains)
end
if :seed in keys(kwargs)
if typeof(kwargs[:seed]) == Int
m.seed = repeat([kwargs[:seed]], m.num_chains)
else
if length(kwargs[:seed]) == m.num_chains
m.seed = kwargs[:seed]
else
m.seed = rand(primes(1, 20000001), m.num_chains)
end
end
end

if m.num_threads > 1
@info "Currently running StanPathfinder with num_threads>1 can lead to problematic results."
end
Expand Down

0 comments on commit 9fb7d9e

Please sign in to comment.