diff --git a/examples/BernoulliOptimize/bernoulli_optimize.jl b/examples/BernoulliOptimize/bernoulli_optimize.jl index 3bfbed9..4434f9f 100644 --- a/examples/BernoulliOptimize/bernoulli_optimize.jl +++ b/examples/BernoulliOptimize/bernoulli_optimize.jl @@ -30,7 +30,7 @@ cd(ProjDir) #do println() display(optim) println() - println("Test round.(mean(optim[1][\"theta\"]), digits=1) ≈ 0.3") + println("Test round.(mean(optim[\"theta\"]), digits=1) ≈ 0.3") @test round.(mean(optim["theta"]), digits=1) ≈ 0.3 end diff --git a/src/main/stancode.jl b/src/main/stancode.jl index 49bbc33..abdce2f 100644 --- a/src/main/stancode.jl +++ b/src/main/stancode.jl @@ -96,11 +96,12 @@ function stan( @assert isdir(ProjDir) "Incorrect ProjDir specified: $(ProjDir)" @assert isdir(model.tmpdir) "$(model.tmpdir) not created" - cd(model.tmpdir) do - isfile("$(model.name)_build.log") && rm("$(model.name)_build.log") - isfile("$(model.name)_make.log") && rm("$(model.name)_make.log") - isfile("$(model.name)_run.log") && rm("$(model.name)_run.log") - end + absolute_tempdir_path = cd(pwd,model.tmpdir) + path_prefix = joinpath(splitpath(absolute_tempdir_path)..., model.name) + model.object_file = path_prefix + isfile("$(path_prefix)_build.log") && rm("$(path_prefix)_build.log") + isfile("$(path_prefix)_make.log") && rm("$(path_prefix)_make.log") + isfile("$(path_prefix)_run.log") && rm("$(path_prefix)_run.log") cd(CmdStanDir) do local tmpmodelname::String @@ -126,138 +127,136 @@ function stan( end end - cd(model.tmpdir) do - if data != Nothing && (typeof(data) <: AbstractString || check_dct_type(data)) - if typeof(data) <: AbstractString + if data != Nothing && (typeof(data) <: AbstractString || check_dct_type(data)) + if typeof(data) <: AbstractString + for i in 1:model.nchains + cp(data, "$(path_prefix)_$(i).data.R", force=true) + end + else + if typeof(data) <: Array && length(data) == model.nchains for i in 1:model.nchains - cp(data, "$(model.name)_$(i).data.R", force=true) - end + if length(keys(data[i])) > 0 + update_R_file("$(path_prefix)_$(i).data.R", data[i]) + end + end else - if typeof(data) <: Array && length(data) == model.nchains + if typeof(data) <: Array for i in 1:model.nchains - if length(keys(data[i])) > 0 - update_R_file("$(model.name)_$(i).data.R", data[i]) + if length(keys(data[1])) > 0 + update_R_file("$(path_prefix)_$(i).data.R", data[1]) end end else - if typeof(data) <: Array - for i in 1:model.nchains - if length(keys(data[1])) > 0 - update_R_file("$(model.name)_$(i).data.R", data[1]) - end - end - else - for i in 1:model.nchains - if length(keys(data)) > 0 - update_R_file("$(model.name)_$(i).data.R", data) - end + for i in 1:model.nchains + if length(keys(data)) > 0 + update_R_file("$(path_prefix)_$(i).data.R", data) end end end end end - - if init != Nothing && (typeof(init) <: AbstractString || check_dct_type(init)) - if typeof(init) <: AbstractString + end + + if init != Nothing && (typeof(init) <: AbstractString || check_dct_type(init)) + if typeof(init) <: AbstractString + for i in 1:model.nchains + cp(init, "$(path_prefix)_$(i).init.R", force=true) + end + else + if typeof(init) <: Array && length(init) == model.nchains for i in 1:model.nchains - cp(init, "$(model.name)_$(i).init.R", force=true) - end + if length(keys(init[i])) > 0 + update_R_file("$(path_prefix)_$(i).init.R", init[i]) + end + end else - if typeof(init) <: Array && length(init) == model.nchains + if typeof(init) <: Array for i in 1:model.nchains - if length(keys(init[i])) > 0 - update_R_file("$(model.name)_$(i).init.R", init[i]) + if length(keys(init[1])) > 0 + update_R_file("$(path_prefix)_$(i).init.R", init[1]) end end else - if typeof(init) <: Array - for i in 1:model.nchains - if length(keys(init[1])) > 0 - update_R_file("$(model.name)_$(i).init.R", init[1]) - end - end - else - for i in 1:model.nchains - if length(keys(init)) > 0 - update_R_file("$(model.name)_$(i).init.R", init) - end + for i in 1:model.nchains + if length(keys(init)) > 0 + update_R_file("$(path_prefix)_$(i).init.R", init) end end end end end + end + + for i in 1:model.nchains + model.id = i + model.data_file ="$(path_prefix)_$(i).data.R" + if init != Nothing + model.init_file = "$(path_prefix)_$(i).init.R" + end + if isa(model.method, Sample) + model.output.file = path_prefix*"_samples_$(i).csv" + isfile(model.output.file) && rm(model.output.file) + if diagnostics + model.output.diagnostic_file = path_prefix*"_diagnostics_$(i).csv" + isfile(model.output.diagnostic_file) && rm(model.output.diagnostic_file) + end + elseif isa(model.method, Optimize) + model.output.file = path_prefix*"_optimize_$(i).csv" + isfile(model.output.file) && rm(model.output.file) + elseif isa(model.method, Variational) + model.output.file = path_prefix*"_variational_$(i).csv" + isfile(model.output.file) && rm(model.output.file) + elseif isa(model.method, Diagnose) + model.output.file = path_prefix*"_diagnose_$(i).csv" + isfile(model.output.file) && rm(model.output.file) + end + model.command[i] = cmdline(model) + end + + try + if file_run_log + run(pipeline(par(model.command), stdout="$(path_prefix)_run.log")) + else + run(par(model.command)) + end + catch e + println("\nAn error occurred while running the previously compiled Stan program.\n") + print("Please check the contents of file $(model.name)_run.log and the") + println("'command' field in the Stanmodel, e.g. stanmodel.command.\n") + error("Return code = -5") + end + + local samplefiles = String[] + local ftype + # local cnames = String[] + + if typeof(model.method) in [Sample, Variational] + if isa(model.method, Sample) + ftype = diagnostics ? "diagnostics" : "samples" + else + ftype = lowercase(string(typeof(model.method))) + end for i in 1:model.nchains - model.id = i - model.data_file ="$(model.name)_$(i).data.R" - if init != Nothing - model.init_file = "$(model.name)_$(i).init.R" - end - if isa(model.method, Sample) - model.output.file = model.name*"_samples_$(i).csv" - isfile("$(model.name)_samples_$(i).csv") && rm("$(model.name)_samples_$(i).csv") - if diagnostics - model.output.diagnostic_file = model.name*"_diagnostics_$(i).csv" - isfile("$(model.name)_diagnostics_$(i).csv") && rm("$(model.name)_diagnostics_$(i).csv") - end - elseif isa(model.method, Optimize) - isfile("$(model.name)_optimize_$(i).csv") && rm("$(model.name)_optimize_$(i).csv") - model.output.file = model.name*"_optimize_$(i).csv" - elseif isa(model.method, Variational) - isfile("$(model.name)_variational_$(i).csv") && rm("$(model.name)_variational_$(i).csv") - model.output.file = model.name*"_variational_$(i).csv" - elseif isa(model.method, Diagnose) - isfile("$(model.name)_diagnose_$(i).csv") && rm("$(model.name)_diagnose_$(i).csv") - model.output.file = model.name*"_diagnose_$(i).csv" - end - model.command[i] = cmdline(model) + push!(samplefiles, "$(path_prefix)_$(ftype)_$(i).csv") end - try - if file_run_log - run(pipeline(par(model.command), stdout="$(model.name)_run.log")) - else - run(par(model.command)) - end - catch e - println("\nAn error occurred while running the previously compiled Stan program.\n") - print("Please check the contents of file $(model.name)_run.log and the") - println("'command' field in the Stanmodel, e.g. stanmodel.command.\n") - error("Return code = -5") + if summary + stan_summary(model, par(samplefiles), CmdStanDir=CmdStanDir) end - local samplefiles = String[] - local ftype - # local cnames = String[] + (res, cnames) = read_samples(model, diagnostics) - if typeof(model.method) in [Sample, Variational] - if isa(model.method, Sample) - ftype = diagnostics ? "diagnostics" : "samples" - else - ftype = lowercase(string(typeof(model.method))) - end - - for i in 1:model.nchains - push!(samplefiles, "$(model.name)_$(ftype)_$(i).csv") - end - - if summary - stan_summary(model, par(samplefiles), CmdStanDir=CmdStanDir) - end - - (res, cnames) = read_samples(model, diagnostics) - - elseif isa(model.method, Optimize) - res, cnames = read_optimize(model) + elseif isa(model.method, Optimize) + res, cnames = read_optimize(model) - elseif isa(model.method, Diagnose) - res, cnames = read_diagnose(model) + elseif isa(model.method, Diagnose) + res, cnames = read_diagnose(model) - else - println("\nAn unknown method is specified in the call to stan().") - error("Return code = -10") - end - end # cd() + else + println("\nAn unknown method is specified in the call to stan().") + error("Return code = -10") + end if model.output_format != :array start_sample = 1 diff --git a/src/main/stanmodel.jl b/src/main/stanmodel.jl index 0ebf6d0..06dadd0 100644 --- a/src/main/stanmodel.jl +++ b/src/main/stanmodel.jl @@ -41,6 +41,7 @@ mutable struct Stanmodel monitors::Vector{String} data_file::String command::Vector{Base.AbstractCmd} + object_file::String method::Method random::Random init_file::String @@ -155,6 +156,7 @@ function Stanmodel( id::Int=0 data_file::String="" + object_file::String="" init_file::String="" cmdarray = fill(``, nchains) @@ -173,7 +175,7 @@ function Stanmodel( Stanmodel(name, nchains, num_warmup, num_samples, thin, id, model, model_file, monitors, - data_file, cmdarray, method, random, + data_file, cmdarray, object_file, method, random, init_file, output, printsummary, pdir, tmpdir, output_format); end diff --git a/src/utilities/create_cmd_line.jl b/src/utilities/create_cmd_line.jl index 7fc8425..61849c7 100644 --- a/src/utilities/create_cmd_line.jl +++ b/src/utilities/create_cmd_line.jl @@ -23,7 +23,7 @@ function cmdline(m) cmd = `` if isa(m, Stanmodel) # Handle the model name field for unix and windows - cmd = @static Sys.isunix() ? `./$(getfield(m, :name))` : `cmd /c $(getfield(m, :name)).exe` + cmd = @static Sys.isunix() ? `$(getfield(m, :object_file))` : `cmd /c $(getfield(m, :object_file)).exe` # Method (sample, optimize, variational and diagnose) specific portion of the model cmd = `$cmd $(cmdline(getfield(m, :method)))` diff --git a/src/utilities/read_diagnose.jl b/src/utilities/read_diagnose.jl index 0dbd7cd..6cc7d3b 100644 --- a/src/utilities/read_diagnose.jl +++ b/src/utilities/read_diagnose.jl @@ -27,7 +27,7 @@ function read_diagnose(model::Stanmodel) local sstr for i in 1:model.nchains - if isfile("$(model.name)_$(res_type)_$(i).csv") + if isfile("$(model.object_file)_$(res_type)_$(i).csv") ## A result type file for chain i is present ## @@ -35,7 +35,7 @@ function read_diagnose(model::Stanmodel) # Extract cmdstan version - str = read("$(model.name)_$(res_type)_$(i).csv", String) + str = read("$(model.object_file)_$(res_type)_$(i).csv", String) sstr = split(str) tdict[:stan_version] = "$(parse(Int, sstr[4])).$(parse(Int, sstr[8])).$(parse(Int, sstr[12]))" end diff --git a/src/utilities/read_optimize.jl b/src/utilities/read_optimize.jl index bec273c..39aca94 100644 --- a/src/utilities/read_optimize.jl +++ b/src/utilities/read_optimize.jl @@ -28,12 +28,12 @@ function read_optimize(model::Stanmodel) tdict = Dict() for i in 1:model.nchains - if isfile("$(model.name)_$(res_type)_$(i).csv") + if isfile("$(model.object_file)_$(res_type)_$(i).csv") # A result type file for chain i is present ## - instream = open("$(model.name)_$(res_type)_$(i).csv") + instream = open("$(model.object_file)_$(res_type)_$(i).csv") if i == 1 - open("$(model.name)_$(res_type)_$(i).csv") do instream + open("$(model.object_file)_$(res_type)_$(i).csv") do instream str = read(instream, String) sstr = split(str) tdict[:stan_major_version] = [parse(Int, sstr[4])] @@ -43,7 +43,7 @@ function read_optimize(model::Stanmodel) end # After reopening the file, skip all comment lines - open("$(model.name)_$(res_type)_$(i).csv") do instream + open("$(model.object_file)_$(res_type)_$(i).csv") do instream skipchars(isspace, instream, linecomment='#') line = Unicode.normalize(readline(instream), newline2lf=true) diff --git a/src/utilities/read_samples.jl b/src/utilities/read_samples.jl index ca3b1eb..0275104 100644 --- a/src/utilities/read_samples.jl +++ b/src/utilities/read_samples.jl @@ -50,11 +50,13 @@ function read_samples(m::Stanmodel, diagnostics=false, warmup_samples=false) end # Read .csv files created by each chain - + + cnames = String[] + a3d = Dict{String,Any}("a" => nothing) for i in 1:m.nchains - if isfile("$(m.name)_$(ftype)_$(i).csv") + if isfile("$(m.object_file)_$(ftype)_$(i).csv") - instream = open("$(m.name)_$(ftype)_$(i).csv") + instream = open("$(m.object_file)_$(ftype)_$(i).csv") # Skip initial set of commented lines, e.g. containing # cmdstan version info, etc. @@ -78,8 +80,8 @@ function read_samples(m::Stanmodel, diagnostics=false, warmup_samples=false) # Preserve cnames and create a3d if i == 1 - global cnames = convert.(String, idx[indvec]) - global a3d = fill(0.0, noofsamples, length(indvec), m.nchains) + append!(cnames, convert.(String, idx[indvec])) + a3d["a"] = fill(0.0, noofsamples, length(indvec), m.nchains) end # Read in the samples for all chains @@ -93,12 +95,12 @@ function read_samples(m::Stanmodel, diagnostics=false, warmup_samples=false) else flds = parse.(Float64, split(strip(line), ",")) flds = reshape(flds[indvec], 1, length(indvec)) - a3d[j,:,i] = flds + a3d["a"][j,:,i] = flds end end # read in samples end # select next file if it exists end # loop over all chains - (a3d, cnames) + (a3d["a"], cnames) end # end of read_samples diff --git a/src/utilities/read_summary.jl b/src/utilities/read_summary.jl index 2bcd50d..c1e6158 100644 --- a/src/utilities/read_summary.jl +++ b/src/utilities/read_summary.jl @@ -19,7 +19,7 @@ read_summary(m::Stanmodel) """ function read_summary(m::Stanmodel) - fname = "$(m.tmpdir)/$(m.name)_summary.csv" + fname = "$(m.object_file)_summary.csv" !isfile(fname) && stan_summary(m) df = CSV.read(fname, DataFrame; delim=",", comment="#") diff --git a/src/utilities/read_variational.jl b/src/utilities/read_variational.jl index a674e28..ff87324 100644 --- a/src/utilities/read_variational.jl +++ b/src/utilities/read_variational.jl @@ -24,8 +24,8 @@ function read_variational(m::Stanmodel) ftype = lowercase(string(typeof(m.method))) for i in 1:m.nchains - if isfile("$(m.name)_$(ftype)_$(i).csv") - open("$(m.name)_$(ftype)_$(i).csv") do instream + if isfile("$(m.object_file)_$(ftype)_$(i).csv") + open("$(m.object_file)_$(ftype)_$(i).csv") do instream skipchars(isspace, instream, linecomment='#') line = Unicode.normalize(readline(instream), newline2lf=true) idx = split(strip(line), ",") diff --git a/src/utilities/stan_summary.jl b/src/utilities/stan_summary.jl index ed14b5b..a211cca 100644 --- a/src/utilities/stan_summary.jl +++ b/src/utilities/stan_summary.jl @@ -32,7 +32,7 @@ function stan_summary(model::Stanmodel, file::String; CmdStanDir=CMDSTAN_HOME) try pstring = joinpath("$(CmdStanDir)", "bin", "stansummary") - csvfile = "$(model.name)_summary.csv" + csvfile = "$(model.object_file)_summary.csv" isfile(csvfile) && rm(csvfile) cmd = `$(pstring) --csv_file=$(csvfile) $(file)` resfile = open(cmd; read=true) @@ -77,7 +77,7 @@ function stan_summary(model::Stanmodel, filecmd::Cmd; CmdStanDir=CMDSTAN_HOME) try pstring = joinpath("$(CmdStanDir)", "bin", "stansummary") - csvfile = "$(model.name)_summary.csv" + csvfile = "$(model.object_file)_summary.csv" isfile(csvfile) && rm(csvfile) cmd = `$(pstring) --csv_file=$(csvfile) $(filecmd)` if model.printsummary