Skip to content
This repository has been archived by the owner on Nov 18, 2022. It is now read-only.

Commit

Permalink
Merge pull request #100 from Byrth/enhancement/reduce_cd_dependence
Browse files Browse the repository at this point in the history
Switch to absolute path
  • Loading branch information
goedman authored Aug 22, 2020
2 parents 98d5754 + 8212224 commit 3b6d5ad
Show file tree
Hide file tree
Showing 10 changed files with 130 additions and 127 deletions.
2 changes: 1 addition & 1 deletion examples/BernoulliOptimize/bernoulli_optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ cd(ProjDir) #do
println()
display(optim)
println()
println("Test round.(mean(optim[1][\"theta\"]), digits=1) ≈ 0.3")
println("Test round.(mean(optim[\"theta\"]), digits=1) ≈ 0.3")
@test round.(mean(optim["theta"]), digits=1) 0.3
end

Expand Down
211 changes: 105 additions & 106 deletions src/main/stancode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,12 @@ function stan(
@assert isdir(ProjDir) "Incorrect ProjDir specified: $(ProjDir)"
@assert isdir(model.tmpdir) "$(model.tmpdir) not created"

cd(model.tmpdir) do
isfile("$(model.name)_build.log") && rm("$(model.name)_build.log")
isfile("$(model.name)_make.log") && rm("$(model.name)_make.log")
isfile("$(model.name)_run.log") && rm("$(model.name)_run.log")
end
absolute_tempdir_path = cd(pwd,model.tmpdir)
path_prefix = joinpath(splitpath(absolute_tempdir_path)..., model.name)
model.object_file = path_prefix
isfile("$(path_prefix)_build.log") && rm("$(path_prefix)_build.log")
isfile("$(path_prefix)_make.log") && rm("$(path_prefix)_make.log")
isfile("$(path_prefix)_run.log") && rm("$(path_prefix)_run.log")

cd(CmdStanDir) do
local tmpmodelname::String
Expand All @@ -126,138 +127,136 @@ function stan(
end
end

cd(model.tmpdir) do
if data != Nothing && (typeof(data) <: AbstractString || check_dct_type(data))
if typeof(data) <: AbstractString
if data != Nothing && (typeof(data) <: AbstractString || check_dct_type(data))
if typeof(data) <: AbstractString
for i in 1:model.nchains
cp(data, "$(path_prefix)_$(i).data.R", force=true)
end
else
if typeof(data) <: Array && length(data) == model.nchains
for i in 1:model.nchains
cp(data, "$(model.name)_$(i).data.R", force=true)
end
if length(keys(data[i])) > 0
update_R_file("$(path_prefix)_$(i).data.R", data[i])
end
end
else
if typeof(data) <: Array && length(data) == model.nchains
if typeof(data) <: Array
for i in 1:model.nchains
if length(keys(data[i])) > 0
update_R_file("$(model.name)_$(i).data.R", data[i])
if length(keys(data[1])) > 0
update_R_file("$(path_prefix)_$(i).data.R", data[1])
end
end
else
if typeof(data) <: Array
for i in 1:model.nchains
if length(keys(data[1])) > 0
update_R_file("$(model.name)_$(i).data.R", data[1])
end
end
else
for i in 1:model.nchains
if length(keys(data)) > 0
update_R_file("$(model.name)_$(i).data.R", data)
end
for i in 1:model.nchains
if length(keys(data)) > 0
update_R_file("$(path_prefix)_$(i).data.R", data)
end
end
end
end
end

if init != Nothing && (typeof(init) <: AbstractString || check_dct_type(init))
if typeof(init) <: AbstractString
end

if init != Nothing && (typeof(init) <: AbstractString || check_dct_type(init))
if typeof(init) <: AbstractString
for i in 1:model.nchains
cp(init, "$(path_prefix)_$(i).init.R", force=true)
end
else
if typeof(init) <: Array && length(init) == model.nchains
for i in 1:model.nchains
cp(init, "$(model.name)_$(i).init.R", force=true)
end
if length(keys(init[i])) > 0
update_R_file("$(path_prefix)_$(i).init.R", init[i])
end
end
else
if typeof(init) <: Array && length(init) == model.nchains
if typeof(init) <: Array
for i in 1:model.nchains
if length(keys(init[i])) > 0
update_R_file("$(model.name)_$(i).init.R", init[i])
if length(keys(init[1])) > 0
update_R_file("$(path_prefix)_$(i).init.R", init[1])
end
end
else
if typeof(init) <: Array
for i in 1:model.nchains
if length(keys(init[1])) > 0
update_R_file("$(model.name)_$(i).init.R", init[1])
end
end
else
for i in 1:model.nchains
if length(keys(init)) > 0
update_R_file("$(model.name)_$(i).init.R", init)
end
for i in 1:model.nchains
if length(keys(init)) > 0
update_R_file("$(path_prefix)_$(i).init.R", init)
end
end
end
end
end
end

for i in 1:model.nchains
model.id = i
model.data_file ="$(path_prefix)_$(i).data.R"
if init != Nothing
model.init_file = "$(path_prefix)_$(i).init.R"
end
if isa(model.method, Sample)
model.output.file = path_prefix*"_samples_$(i).csv"
isfile(model.output.file) && rm(model.output.file)
if diagnostics
model.output.diagnostic_file = path_prefix*"_diagnostics_$(i).csv"
isfile(model.output.diagnostic_file) && rm(model.output.diagnostic_file)
end
elseif isa(model.method, Optimize)
model.output.file = path_prefix*"_optimize_$(i).csv"
isfile(model.output.file) && rm(model.output.file)
elseif isa(model.method, Variational)
model.output.file = path_prefix*"_variational_$(i).csv"
isfile(model.output.file) && rm(model.output.file)
elseif isa(model.method, Diagnose)
model.output.file = path_prefix*"_diagnose_$(i).csv"
isfile(model.output.file) && rm(model.output.file)
end
model.command[i] = cmdline(model)
end

try
if file_run_log
run(pipeline(par(model.command), stdout="$(path_prefix)_run.log"))
else
run(par(model.command))
end
catch e
println("\nAn error occurred while running the previously compiled Stan program.\n")
print("Please check the contents of file $(model.name)_run.log and the")
println("'command' field in the Stanmodel, e.g. stanmodel.command.\n")
error("Return code = -5")
end

local samplefiles = String[]
local ftype
# local cnames = String[]

if typeof(model.method) in [Sample, Variational]
if isa(model.method, Sample)
ftype = diagnostics ? "diagnostics" : "samples"
else
ftype = lowercase(string(typeof(model.method)))
end

for i in 1:model.nchains
model.id = i
model.data_file ="$(model.name)_$(i).data.R"
if init != Nothing
model.init_file = "$(model.name)_$(i).init.R"
end
if isa(model.method, Sample)
model.output.file = model.name*"_samples_$(i).csv"
isfile("$(model.name)_samples_$(i).csv") && rm("$(model.name)_samples_$(i).csv")
if diagnostics
model.output.diagnostic_file = model.name*"_diagnostics_$(i).csv"
isfile("$(model.name)_diagnostics_$(i).csv") && rm("$(model.name)_diagnostics_$(i).csv")
end
elseif isa(model.method, Optimize)
isfile("$(model.name)_optimize_$(i).csv") && rm("$(model.name)_optimize_$(i).csv")
model.output.file = model.name*"_optimize_$(i).csv"
elseif isa(model.method, Variational)
isfile("$(model.name)_variational_$(i).csv") && rm("$(model.name)_variational_$(i).csv")
model.output.file = model.name*"_variational_$(i).csv"
elseif isa(model.method, Diagnose)
isfile("$(model.name)_diagnose_$(i).csv") && rm("$(model.name)_diagnose_$(i).csv")
model.output.file = model.name*"_diagnose_$(i).csv"
end
model.command[i] = cmdline(model)
push!(samplefiles, "$(path_prefix)_$(ftype)_$(i).csv")
end

try
if file_run_log
run(pipeline(par(model.command), stdout="$(model.name)_run.log"))
else
run(par(model.command))
end
catch e
println("\nAn error occurred while running the previously compiled Stan program.\n")
print("Please check the contents of file $(model.name)_run.log and the")
println("'command' field in the Stanmodel, e.g. stanmodel.command.\n")
error("Return code = -5")
if summary
stan_summary(model, par(samplefiles), CmdStanDir=CmdStanDir)
end

local samplefiles = String[]
local ftype
# local cnames = String[]
(res, cnames) = read_samples(model, diagnostics)

if typeof(model.method) in [Sample, Variational]
if isa(model.method, Sample)
ftype = diagnostics ? "diagnostics" : "samples"
else
ftype = lowercase(string(typeof(model.method)))
end

for i in 1:model.nchains
push!(samplefiles, "$(model.name)_$(ftype)_$(i).csv")
end

if summary
stan_summary(model, par(samplefiles), CmdStanDir=CmdStanDir)
end

(res, cnames) = read_samples(model, diagnostics)

elseif isa(model.method, Optimize)
res, cnames = read_optimize(model)
elseif isa(model.method, Optimize)
res, cnames = read_optimize(model)

elseif isa(model.method, Diagnose)
res, cnames = read_diagnose(model)
elseif isa(model.method, Diagnose)
res, cnames = read_diagnose(model)

else
println("\nAn unknown method is specified in the call to stan().")
error("Return code = -10")
end
end # cd()
else
println("\nAn unknown method is specified in the call to stan().")
error("Return code = -10")
end

if model.output_format != :array
start_sample = 1
Expand Down
4 changes: 3 additions & 1 deletion src/main/stanmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ mutable struct Stanmodel
monitors::Vector{String}
data_file::String
command::Vector{Base.AbstractCmd}
object_file::String
method::Method
random::Random
init_file::String
Expand Down Expand Up @@ -155,6 +156,7 @@ function Stanmodel(

id::Int=0
data_file::String=""
object_file::String=""
init_file::String=""
cmdarray = fill(``, nchains)

Expand All @@ -173,7 +175,7 @@ function Stanmodel(
Stanmodel(name, nchains,
num_warmup, num_samples, thin,
id, model, model_file, monitors,
data_file, cmdarray, method, random,
data_file, cmdarray, object_file, method, random,
init_file, output, printsummary, pdir, tmpdir, output_format);
end

Expand Down
2 changes: 1 addition & 1 deletion src/utilities/create_cmd_line.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function cmdline(m)
cmd = ``
if isa(m, Stanmodel)
# Handle the model name field for unix and windows
cmd = @static Sys.isunix() ? `./$(getfield(m, :name))` : `cmd /c $(getfield(m, :name)).exe`
cmd = @static Sys.isunix() ? `$(getfield(m, :object_file))` : `cmd /c $(getfield(m, :object_file)).exe`

# Method (sample, optimize, variational and diagnose) specific portion of the model
cmd = `$cmd $(cmdline(getfield(m, :method)))`
Expand Down
4 changes: 2 additions & 2 deletions src/utilities/read_diagnose.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ function read_diagnose(model::Stanmodel)
local sstr

for i in 1:model.nchains
if isfile("$(model.name)_$(res_type)_$(i).csv")
if isfile("$(model.object_file)_$(res_type)_$(i).csv")

## A result type file for chain i is present ##

if i == 1

# Extract cmdstan version

str = read("$(model.name)_$(res_type)_$(i).csv", String)
str = read("$(model.object_file)_$(res_type)_$(i).csv", String)
sstr = split(str)
tdict[:stan_version] = "$(parse(Int, sstr[4])).$(parse(Int, sstr[8])).$(parse(Int, sstr[12]))"
end
Expand Down
8 changes: 4 additions & 4 deletions src/utilities/read_optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ function read_optimize(model::Stanmodel)
tdict = Dict()

for i in 1:model.nchains
if isfile("$(model.name)_$(res_type)_$(i).csv")
if isfile("$(model.object_file)_$(res_type)_$(i).csv")

# A result type file for chain i is present ##
instream = open("$(model.name)_$(res_type)_$(i).csv")
instream = open("$(model.object_file)_$(res_type)_$(i).csv")
if i == 1
open("$(model.name)_$(res_type)_$(i).csv") do instream
open("$(model.object_file)_$(res_type)_$(i).csv") do instream
str = read(instream, String)
sstr = split(str)
tdict[:stan_major_version] = [parse(Int, sstr[4])]
Expand All @@ -43,7 +43,7 @@ function read_optimize(model::Stanmodel)
end

# After reopening the file, skip all comment lines
open("$(model.name)_$(res_type)_$(i).csv") do instream
open("$(model.object_file)_$(res_type)_$(i).csv") do instream
skipchars(isspace, instream, linecomment='#')
line = Unicode.normalize(readline(instream), newline2lf=true)

Expand Down
16 changes: 9 additions & 7 deletions src/utilities/read_samples.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,13 @@ function read_samples(m::Stanmodel, diagnostics=false, warmup_samples=false)
end

# Read .csv files created by each chain


cnames = String[]
a3d = Dict{String,Any}("a" => nothing)
for i in 1:m.nchains
if isfile("$(m.name)_$(ftype)_$(i).csv")
if isfile("$(m.object_file)_$(ftype)_$(i).csv")

instream = open("$(m.name)_$(ftype)_$(i).csv")
instream = open("$(m.object_file)_$(ftype)_$(i).csv")

# Skip initial set of commented lines, e.g. containing
# cmdstan version info, etc.
Expand All @@ -78,8 +80,8 @@ function read_samples(m::Stanmodel, diagnostics=false, warmup_samples=false)
# Preserve cnames and create a3d

if i == 1
global cnames = convert.(String, idx[indvec])
global a3d = fill(0.0, noofsamples, length(indvec), m.nchains)
append!(cnames, convert.(String, idx[indvec]))
a3d["a"] = fill(0.0, noofsamples, length(indvec), m.nchains)
end

# Read in the samples for all chains
Expand All @@ -93,12 +95,12 @@ function read_samples(m::Stanmodel, diagnostics=false, warmup_samples=false)
else
flds = parse.(Float64, split(strip(line), ","))
flds = reshape(flds[indvec], 1, length(indvec))
a3d[j,:,i] = flds
a3d["a"][j,:,i] = flds
end
end # read in samples
end # select next file if it exists
end # loop over all chains

(a3d, cnames)
(a3d["a"], cnames)

end # end of read_samples
2 changes: 1 addition & 1 deletion src/utilities/read_summary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ read_summary(m::Stanmodel)
"""
function read_summary(m::Stanmodel)

fname = "$(m.tmpdir)/$(m.name)_summary.csv"
fname = "$(m.object_file)_summary.csv"
!isfile(fname) && stan_summary(m)

df = CSV.read(fname, DataFrame; delim=",", comment="#")
Expand Down
Loading

0 comments on commit 3b6d5ad

Please sign in to comment.