Skip to content

Commit

Permalink
add diags and plotting for it in examples, VERY HACKY (#279)
Browse files Browse the repository at this point in the history
updates from hpc

rm prints

repeats for uq_for_edmf

add save jld2 and lines+series depending on repeats

add jld2 and log-scale

rm prints, typos

format
  • Loading branch information
odunbar authored Jul 13, 2024
1 parent 8d6e7c0 commit f2e95cf
Show file tree
Hide file tree
Showing 8 changed files with 223 additions and 62 deletions.
13 changes: 9 additions & 4 deletions examples/EDMF_data/plot_posterior.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ using CalibrateEmulateSample.ParameterDistributions
# date = Date(year,month,day)

# 2-parameter calibration exp
exp_name = "ent-det-calibration"
date_of_run = Date(2023, 10, 5)
#exp_name = "ent-det-calibration"
#date_of_run = Date(2023, 10, 17)

# 5-parameter calibration exp
#exp_name = "ent-det-tked-tkee-stab-calibration"
#date_of_run = Date(2023,10,4)
exp_name = "ent-det-tked-tkee-stab-calibration"
date_of_run = Date(2024, 2, 2)

# Output figure read/write directory
figure_save_directory = joinpath(@__DIR__, "output", exp_name, string(date_of_run))
Expand Down Expand Up @@ -50,3 +50,8 @@ p = pairplot(data => (PairPlots.Scatter(),))
trans_p = pairplot(transformed_data => (PairPlots.Scatter(),))
save(density_filepath, p)
save(transformed_density_filepath, trans_p)

density_filepath = joinpath(figure_save_directory, "posterior_dist_comp.pdf")
transformed_density_filepath = joinpath(figure_save_directory, "posterior_dist_phys.pdf")
save(density_filepath, p)
save(transformed_density_filepath, trans_p)
148 changes: 106 additions & 42 deletions examples/EDMF_data/uq_for_edmf.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
#include(joinpath(@__DIR__, "..", "ci", "linkfig.jl"))
#include(joinpath(@__DIR__, "..", "ci", "linkfig.jl"))
PLOT_FLAG = false

# Import modules
using Distributions # probability distributions and associated functions
using LinearAlgebra
ENV["GKSwstype"] = "100"
using Plots
using CairoMakie
using Random
using JLD2
using NCDatasets
Expand All @@ -28,10 +29,10 @@ Random.seed!(rng_seed)
function main()

# 2-parameter calibration exp
exp_name = "ent-det-calibration"
#exp_name = "ent-det-calibration"

# 5-parameter calibration exp
#exp_name = "ent-det-tked-tkee-stab-calibration"
exp_name = "ent-det-tked-tkee-stab-calibration"


# Output figure save directory
Expand Down Expand Up @@ -120,6 +121,7 @@ function main()
for plot_i in 1:size(outputs, 1)
p = scatter(inputs_constrained[1, :], inputs_constrained[2, :], zcolor = outputs[plot_i, :])
savefig(p, joinpath(figure_save_directory, "output_" * string(plot_i) * ".png"))
savefig(p, joinpath(figure_save_directory, "output_" * string(plot_i) * ".pdf"))
end
println("finished plotting ensembles.")
end
Expand Down Expand Up @@ -201,52 +203,114 @@ function main()
cases = [
"GP", # diagonalize, train scalar GP, assume diag inputs
"RF-vector-svd-nonsep",
"RF-vector-nosvd-nonsep", # don't perform decorrelation
]
case = cases[2]

overrides = Dict(
"verbose" => true,
"train_fraction" => 0.95,
"scheduler" => DataMisfitController(terminate_at = 100),
"cov_sample_multiplier" => 0.5,
"n_iteration" => 5,
)
nugget = 0.01
rng_seed = 99330
rng = Random.MersenneTwister(rng_seed)
input_dim = size(get_inputs(input_output_pairs), 1)
output_dim = size(get_outputs(input_output_pairs), 1)
if case == "GP"

gppackage = Emulators.SKLJL()
pred_type = Emulators.YType()
mlt = GaussianProcess(
gppackage;
kernel = nothing, # use default squared exponential kernel
prediction_type = pred_type,
noise_learn = false,
case = cases[3]
n_repeats = 2

opt_diagnostics = []
emulators = []
for rep_idx in 1:n_repeats

overrides = Dict(
"verbose" => true,
"train_fraction" => 0.9, #95
"scheduler" => DataMisfitController(terminate_at = 1e5),
"cov_sample_multiplier" => 0.4,
"n_features_opt" => 200,
"n_iteration" => 15,
# "n_ensemble" => 20,
# "localization" => SEC(1.0, 0.01), # localization / sample error correction for small ensembles
)
elseif case ∈ ["RF-vector-svd-nonsep"]
kernel_structure = NonseparableKernel(LowRankFactor(3, nugget))
n_features = 500

mlt = VectorRandomFeatureInterface(
n_features,
input_dim,
output_dim,
rng = rng,
kernel_structure = kernel_structure,
optimizer_options = overrides,
nugget = 1e-10#1e-12#0.01
rng_seed = 99330
rng = Random.MersenneTwister(rng_seed)
input_dim = size(get_inputs(input_output_pairs), 1)
output_dim = size(get_outputs(input_output_pairs), 1)
decorrelate = true
if case == "GP"

gppackage = Emulators.SKLJL()
pred_type = Emulators.YType()
mlt = GaussianProcess(
gppackage;
kernel = nothing, # use default squared exponential kernel
prediction_type = pred_type,
noise_learn = false,
)
elseif case ∈ ["RF-vector-svd-nonsep"]
kernel_structure = NonseparableKernel(LowRankFactor(3, nugget))
n_features = 500

mlt = VectorRandomFeatureInterface(
n_features,
input_dim,
output_dim,
rng = rng,
kernel_structure = kernel_structure,
optimizer_options = overrides,
)
elseif case ∈ ["RF-vector-nosvd-nonsep"]
kernel_structure = NonseparableKernel(LowRankFactor(3, nugget))
n_features = 500

mlt = VectorRandomFeatureInterface(
n_features,
input_dim,
output_dim,
rng = rng,
kernel_structure = kernel_structure,
optimizer_options = overrides,
)
decorrelate = false
end

# Fit an emulator to the data
normalized = true

emulator = Emulator(
mlt,
input_output_pairs;
obs_noise_cov = truth_cov,
normalize_inputs = normalized,
decorrelate = decorrelate,
)

# Optimize the GP hyperparameters for better fit
optimize_hyperparameters!(emulator)
if case ∈ ["RF-vector-nosvd-nonsep", "RF-vector-svd-nonsep"]
push!(opt_diagnostics, get_optimizer(mlt)[1]) #length-1 vec of vec -> vec
end

for rep_idx in n_repeats
push!(emulators, emulator)
end
end
emulator = emulators[1]

# Fit an emulator to the data
normalized = true
# plot eki convergence plot
if length(opt_diagnostics) > 0
err_cols = reduce(hcat, opt_diagnostics) #error for each repeat as columns?

emulator = Emulator(mlt, input_output_pairs; obs_noise_cov = truth_cov, normalize_inputs = normalized)
#save data
error_filepath = joinpath(data_save_directory, "eki_conv_error.jld2")
save(error_filepath, "error", err_cols)

# Optimize the GP hyperparameters for better fit
optimize_hyperparameters!(emulator)
# plot the convergence curves of all repeats
f5 = Figure(resolution = (1.618 * 300, 300), markersize = 4)
ax_conv = Axis(f5[1, 1], xlabel = "Iteration", ylabel = "max-normalized error")
if n_repeats == 1
lines!(ax_conv, collect(1:size(err_cols, 1))[:], err_cols[:], solid_color = :blue) # If just one repeat
else
for idx in 1:size(err_cols, 1)
err_normalized = (err_cols' ./ err_cols[1, :])' # divide each series by its first-iteration error (assumed to be its max), so all errors start at 1
series!(ax_conv, err_normalized', solid_color = :blue)
end
end
save(joinpath(figure_save_directory, "eki-conv_$(case).png"), f5, px_per_unit = 3)
save(joinpath(figure_save_directory, "eki-conv_$(case).pdf"), f5, px_per_unit = 3)

end

emulator_filepath = joinpath(data_save_directory, "emulator.jld2")
save(emulator_filepath, "emulator", emulator)
Expand Down
1 change: 1 addition & 0 deletions examples/Emulator/Ishigami/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GlobalSensitivityAnalysis = "1b10255b-6da3-57ce-9089-d24e8517b87e"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
42 changes: 38 additions & 4 deletions examples/Emulator/Ishigami/emulate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using LinearAlgebra
using CalibrateEmulateSample.EnsembleKalmanProcesses
using CalibrateEmulateSample.Emulators
using CalibrateEmulateSample.DataContainers
using CalibrateEmulateSample.EnsembleKalmanProcesses.Localizers

using CairoMakie, ColorSchemes #for plots
seed = 2589456
Expand Down Expand Up @@ -81,9 +82,13 @@ function main()
case = cases[3]
decorrelate = true
nugget = Float64(1e-12)

overrides =
Dict("verbose" => true, "scheduler" => DataMisfitController(terminate_at = 1e4), "n_features_opt" => 200)
overrides = Dict(
"scheduler" => DataMisfitController(terminate_at = 1e4),
"n_features_opt" => 150,
"n_ensemble" => 30,
"n_iteration" => 20,
"accelerator" => NesterovAccelerator(),
)
if case == "Prior"
# don't do anything
overrides["n_iteration"] = 0
Expand All @@ -92,7 +97,7 @@ function main()

y_preds = []
result_preds = []

opt_diagnostics = []
for rep_idx in 1:n_repeats

# Build ML tools
Expand All @@ -118,6 +123,11 @@ function main()
emulator = Emulator(mlt, iopairs; obs_noise_cov = Γ * I, decorrelate = decorrelate)
optimize_hyperparameters!(emulator)

# get EKP errors - just stored in "optimizer" box for now
if case == "RF-scalar"
diag_tmp = reduce(hcat, get_optimizer(mlt)) # (n_iteration, dim_output=1) convergence for each scalar mode as cols
push!(opt_diagnostics, diag_tmp)
end
# predict on all Sobol points with emulator (example)
y_pred, y_var = predict(emulator, samples', transform_to_real = true)

Expand Down Expand Up @@ -186,6 +196,30 @@ function main()
save(joinpath(output_directory, "ishigami_slices_$(case).pdf"), f2, px_per_unit = 3)


if length(opt_diagnostics) > 0
err_cols = reduce(hcat, opt_diagnostics) #error for each repeat as columns?

#save
error_filepath = joinpath(output_directory, "eki_conv_error.jld2")
save(error_filepath, "error", err_cols)

# plot the convergence curves of all repeats
f3 = Figure(resolution = (1.618 * 300, 300), markersize = 4)
ax_conv = Axis(f3[1, 1], xlabel = "Iteration", ylabel = "Error")

if n_repeats == 1
lines!(ax_conv, collect(1:size(err_cols, 1))[:], err_cols[:], solid_color = :blue) # If just one repeat
else
for idx in 1:size(err_cols, 1)
err_normalized = (err_cols' ./ err_cols[1, :])' # divide each series by its first-iteration error (assumed to be its max), so all errors start at 1
series!(ax_conv, err_normalized', solid_color = :blue)
end
end

save(joinpath(output_directory, "ishigami_eki-conv_$(case).png"), f3, px_per_unit = 3)
save(joinpath(output_directory, "ishigami_eki-conv_$(case).pdf"), f3, px_per_unit = 3)

end
end


Expand Down
45 changes: 37 additions & 8 deletions examples/Emulator/L63/emulate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function main()
end

# rng
rng = MersenneTwister(1232434)
rng = MersenneTwister(1232435)

n_repeats = 20 # repeat exp with same data.
println("run experiment $n_repeats times")
Expand Down Expand Up @@ -92,20 +92,22 @@ function main()
# Emulate
cases = ["GP", "RF-scalar", "RF-scalar-diagin", "RF-svd-nonsep", "RF-nosvd-nonsep", "RF-nosvd-sep"]

case = cases[1]
case = cases[5]

nugget = Float64(1e-12)
u_test = []
u_hist = []
train_err = []
opt_diagnostics = []

for rep_idx in 1:n_repeats

rf_optimizer_overrides = Dict(
"scheduler" => DataMisfitController(terminate_at = 1e4),
"cov_sample_multiplier" => 0.5,
"n_features_opt" => 400,
"n_iteration" => 30,
"accelerator" => ConstantStepNesterovAccelerator(),
"cov_sample_multiplier" => 1.0,
"n_features_opt" => 200,
"n_iteration" => 10, #30
"accelerator" => NesterovAccelerator(),
)

# Build ML tools
Expand Down Expand Up @@ -170,6 +172,11 @@ function main()
emulator = Emulator(mlt, iopairs; obs_noise_cov = Γy, decorrelate = decorrelate)
optimize_hyperparameters!(emulator)

# diagnostics
if case == "RF-nosvd-nonsep"
push!(opt_diagnostics, get_optimizer(mlt)[1]) #length-1 vec of vec -> vec
end


# Predict with emulator
u_test_tmp = zeros(3, length(xspan_test))
Expand Down Expand Up @@ -252,6 +259,30 @@ function main()
JLD2.save(joinpath(output_directory, case * "_l63_histdata.jld2"), "solhist", solhist, "uhist", u_hist)
JLD2.save(joinpath(output_directory, case * "_l63_testdata.jld2"), "solplot", solplot, "uplot", u_test)

# plot eki convergence plot
if length(opt_diagnostics) > 0
err_cols = reduce(hcat, opt_diagnostics) #error for each repeat as columns?

#save
error_filepath = joinpath(output_directory, "eki_conv_error.jld2")
save(error_filepath, "error", err_cols)

# plot the convergence curves of all repeats
f5 = Figure(resolution = (1.618 * 300, 300), markersize = 4)
ax_conv = Axis(f5[1, 1], xlabel = "Iteration", ylabel = "max-normalized error", yscale = log10)
if n_repeats == 1
lines!(ax_conv, collect(1:size(err_cols, 1))[:], err_cols[:], solid_color = :blue) # If just one repeat
else
for idx in 1:size(err_cols, 1)
err_normalized = (err_cols' ./ err_cols[1, :])' # divide each series by its first-iteration error (assumed to be its max), so all errors start at 1
series!(ax_conv, err_normalized', solid_color = :blue)
end
end
save(joinpath(output_directory, "l63_eki-conv_$(case).png"), f5, px_per_unit = 3)
save(joinpath(output_directory, "l63_eki-conv_$(case).pdf"), f5, px_per_unit = 3)

end

# compare marginal histograms to truth - rough measure of fit
sol_cdf = sort(solhist, dims = 2)

Expand All @@ -278,8 +309,6 @@ function main()
lines!(axy, sol_cdf[2, :], unif_samples, color = (:orange, 1.0), linewidth = 4)
lines!(axz, sol_cdf[3, :], unif_samples, color = (:orange, 1.0), linewidth = 4)



# save
save(joinpath(output_directory, case * "_l63_cdfs.png"), f4, px_per_unit = 3)
save(joinpath(output_directory, case * "_l63_cdfs.pdf"), f4, pt_per_unit = 3)
Expand Down
1 change: 0 additions & 1 deletion src/MarkovChainMonteCarlo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,6 @@ function AbstractMCMC.bundle_samples(
)
# Turn all the transitions into a vector-of-vectors.
vals = [vcat(t.params, t.log_density, t.accepted) for t in ts]

# Check if we received any parameter names.
if ismissing(param_names)
param_names = [Symbol(:param_, i) for i in 1:length(keys(ts[1].params))]
Expand Down
Loading

0 comments on commit f2e95cf

Please sign in to comment.