added RF emulator
update project toml

examples for RF

shuffle data

update example

example to produce comparable figs

updates for compatibility with CES 0.2.0 and RF 0.1.0

vector random feature support

added fixes to ensure CES pipeline runs

regularization and Lorenz example

format

allows training with fewer features than data points

VRFI with SVD and Cholesky options

feature number depends on n

GCM example

replace data

multithreading and rng

add ProgressBars

remove high-level threading for now (takes place within LinAlg solvers)

sbatch script

truth at some points

increased default number of optimization features

bugfix: regularization matrix argument

initial Tikhonov regularization for EKI

working TEKI

0 default eki, small tweaks

add logdet complexity

more consistent adding of definiteness

Cholesky/SVD

add logdet to scalar learning

shape bug

logdetI

unite common functions in Random Feature, expand Scalar feature learning

extend reverse svd for covs

add diag terms to MatrixNormal description, default to diagonal regularizations rather than pos-def

add diagonal option

trimmed, and added const hp for diag cov

compatibility with SVD truncation, and more standard pos-def corrections

added scaling to complexity data

change scalar interface

Lorenz 2D stats plot

combine all MLT examples into this

improved interfacing, unification and initial unit testing

condensed into emulate_sample

simplify scalar interface

bug

improved vector interface

reg should be multiplicative! fixed

small edits

update ess.jl

MSE on next ensemble, add input-diag case

inflation

optimizer defaults and cov representation

inflation vec

inflation

utility for ensembles

test pass with new defaults and cov structure

format

format

add RF tests

GP test fails resolved
odunbar committed Apr 6, 2023
1 parent 328a26c commit ecd1eca
Showing 27 changed files with 3,831 additions and 226 deletions.
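
For orientation, here is a minimal sketch of the scalar random-feature workflow, condensed from the calls in examples/Emulator/RandomFeature/scalar_optimize_and_plot_RF.jl added by this commit; the toy data and the test_inputs array are illustrative stand-ins, not part of the diff.

using Random, Distributions, LinearAlgebra
using CalibrateEmulateSample.Emulators
using CalibrateEmulateSample.DataContainers

# Toy problem: 2 inputs, 2 outputs, correlated observational noise (as in the example below)
n, p, d = 100, 2, 2
X = 2.0 * π * rand(p, n)
gx = vcat((sin.(X[1, :]) .+ cos.(X[2, :]))', (sin.(X[1, :]) .- cos.(X[2, :]))')
Σ = 0.1 * [0.8 0.1; 0.1 0.5]
Y = gx .+ rand(MvNormal(zeros(d), Σ), n)
iopairs = PairedDataContainer(X, Y, data_are_columns = true)

# Scalar random-feature interface: 200 features, input dimension p
srfi = ScalarRandomFeatureInterface(200, p)
emulator = Emulator(srfi, iopairs, obs_noise_cov = Σ, normalize_inputs = true)
optimize_hyperparameters!(emulator)

# Predict at new points; inputs are input_dim x N_samples, means/covariances returned in the original output space
test_inputs = 2.0 * π * rand(p, 25)
rf_mean, rf_cov = predict(emulator, test_inputs, transform_to_real = true)

The Gaussian process example touched in this diff is driven through the same Emulator, optimize_hyperparameters! and predict calls, which is what allows the random-feature backend to drop into the existing CES pipeline.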
Project.toml (6 changes: 5 additions & 1 deletion)
@@ -15,9 +15,12 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
+ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RandomFeatures = "36c3bae2-c0c3-419d-b3b4-eebadd35c5e5"
ScikitLearn = "3646fa90-6ef7-5e7e-9f22-8aca16db6324"
+StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

@@ -31,9 +34,10 @@ EnsembleKalmanProcesses = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13, 0.14"
GaussianProcesses = "0.12"
MCMCChains = "4.14, 5"
PyCall = "1.93"
RandomFeatures = "0.2"
ScikitLearn = "0.6"
StatsBase = "0.33"
julia = "1.6"
julia = "1.6, 1.7, 1.8"

[extras]
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
examples/Emulator/GaussianProcess/plot_GP.jl (12 changes: 6 additions & 6 deletions)
@@ -68,9 +68,9 @@ if !isdir(output_directory)
end

#create the machine learning tools: Gaussian Process
-gppackage = GPJL()
+gppackage = SKLJL()
pred_type = YType()
-gaussian_process = GaussianProcess(gppackage, noise_learn = true)
+gaussian_process = GaussianProcess(gppackage, noise_learn = false)

# Generate training data (x-y pairs, where x ∈ ℝ ᵖ, y ∈ ℝ ᵈ)
# x = [x1, x2]: inputs/predictors/features/parameters
@@ -92,7 +92,7 @@ gx[2, :] = g2x

# Add noise η
μ = zeros(d)
-Σ = 0.1 * [[0.8, 0.0] [0.0, 0.5]] # d x d
+Σ = 0.1 * [[0.8, 0.1] [0.1, 0.5]] # d x d
noise_samples = rand(MvNormal(μ, Σ), n)
# y = G(x) + η
Y = gx .+ noise_samples
@@ -182,9 +182,9 @@ println("GP trained")

# Plot mean and variance of the predicted observables y1 and y2
# For this, we generate test points on a x1-x2 grid.
-n_pts = 50
-x1 = range(0.0, stop = 2 * π, length = n_pts)
-x2 = range(0.0, stop = 2 * π, length = n_pts)
+n_pts = 200
+x1 = range(0.0, stop = (4.0 / 5.0) * 2 * π, length = n_pts)
+x2 = range(0.0, stop = (4.0 / 5.0) * 2 * π, length = n_pts)
X1, X2 = meshgrid(x1, x2)
# Input for predict has to be of size N_samples x input_dim
inputs = permutedims(hcat(X1[:], X2[:]), (2, 1))
examples/Emulator/RandomFeature/Project.toml (15 changes: 15 additions & 0 deletions)
@@ -0,0 +1,15 @@
[deps]
CalibrateEmulateSample = "95e48a1f-0bec-4818-9538-3db4340308e3"
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
FiniteDiff = "~2.10"
julia = "~1.6"
examples/Emulator/RandomFeature/scalar_optimize_and_plot_RF.jl (257 changes: 257 additions & 0 deletions)
@@ -0,0 +1,257 @@
# Reference the in-tree version of CalibrateEmulateSample on Julia's load path
include(joinpath(@__DIR__, "..", "..", "ci", "linkfig.jl"))

# Import modules
using Random
using StableRNGs
using Distributions
using Statistics
using LinearAlgebra
using CalibrateEmulateSample.Emulators
using CalibrateEmulateSample.DataContainers
using CalibrateEmulateSample.ParameterDistributions

case = "scalar"
println("running case $case")

plot_flag = true
if plot_flag
using Plots
gr(size = (1500, 700))
Plots.scalefontsizes(1.3)
font = Plots.font("Helvetica", 18)
fontdict = Dict(:guidefont => font, :xtickfont => font, :ytickfont => font, :legendfont => font)

end

function meshgrid(vx::AbstractVector{T}, vy::AbstractVector{T}) where {T}
m, n = length(vy), length(vx)
gx = reshape(repeat(vx, inner = m, outer = 1), m, n)
gy = reshape(repeat(vy, inner = 1, outer = n), m, n)

return gx, gy
end
rng_seed = 41
Random.seed!(rng_seed)
output_directory = joinpath(@__DIR__, "output")
if !isdir(output_directory)
mkdir(output_directory)
end

#problem
n = 100 # number of training points
p = 2 # input dim
d = 2 # output dim
X = 2.0 * π * rand(p, n)
# G(x1, x2)
g1x = sin.(X[1, :]) .+ cos.(X[2, :])
g2x = sin.(X[1, :]) .- cos.(X[2, :])
gx = zeros(2, n)
gx[1, :] = g1x
gx[2, :] = g2x

# Add noise η
μ = zeros(d)
Σ = 0.1 * [[0.8, 0.1] [0.1, 0.5]] # d x d
noise_samples = rand(MvNormal(μ, Σ), n)
# y = G(x) + η
Y = gx .+ noise_samples

iopairs = PairedDataContainer(X, Y, data_are_columns = true)
@assert get_inputs(iopairs) == X
@assert get_outputs(iopairs) == Y

#plot training data with and without noise
if plot_flag
p1 = plot(
X[1, :],
X[2, :],
g1x,
st = :surface,
camera = (30, 60),
c = :cividis,
xlabel = "x1",
ylabel = "x2",
zguidefontrotation = 90,
)

figpath = joinpath(output_directory, "RF_" * case * "_observed_y1nonoise.png")
savefig(figpath)

p2 = plot(
X[1, :],
X[2, :],
g2x,
st = :surface,
camera = (30, 60),
c = :cividis,
xlabel = "x1",
ylabel = "x2",
zguidefontrotation = 90,
)
figpath = joinpath(output_directory, "RF_" * case * "_observed_y2nonoise.png")
savefig(figpath)

p3 = plot(
X[1, :],
X[2, :],
Y[1, :],
st = :surface,
camera = (30, 60),
c = :cividis,
xlabel = "x1",
ylabel = "x2",
zguidefontrotation = 90,
)
figpath = joinpath(output_directory, "RF_" * case * "_observed_y1.png")
savefig(figpath)

p4 = plot(
X[1, :],
X[2, :],
Y[2, :],
st = :surface,
camera = (30, 60),
c = :cividis,
xlabel = "x1",
ylabel = "x2",
zguidefontrotation = 90,
)
figpath = joinpath(output_directory, "RF_" * case * "_observed_y2.png")
savefig(figpath)

end

# setup random features
n_features = 200

srfi = ScalarRandomFeatureInterface(n_features, p)
emulator = Emulator(srfi, iopairs, obs_noise_cov = Σ, normalize_inputs = true)
println("build RF with $n training points and $(n_features) random features.")

optimize_hyperparameters!(emulator) # although RF already optimized

# Plot mean and variance of the predicted observables y1 and y2
# For this, we generate test points on a x1-x2 grid.
n_pts = 200
x1 = range(0.0, stop = 4.0 / 5.0 * 2 * π, length = n_pts)
x2 = range(0.0, stop = 4.0 / 5.0 * 2 * π, length = n_pts)
X1, X2 = meshgrid(x1, x2)
# Input for predict has to be of size input_dim x N_samples
inputs = permutedims(hcat(X1[:], X2[:]), (2, 1))

rf_mean, rf_cov = predict(emulator, inputs, transform_to_real = true)
println("end predictions at ", n_pts * n_pts, " points")


#plot predictions
for y_i in 1:d
rf_var_temp = [diag(rf_cov[j]) for j in 1:length(rf_cov)] # (40000,)
rf_var = permutedims(vcat([x' for x in rf_var_temp]...), (2, 1)) # 2 x 40000

mean_grid = reshape(rf_mean[y_i, :], n_pts, n_pts) # n_pts x n_pts
if plot_flag
p5 = plot(
x1,
x2,
mean_grid,
st = :surface,
camera = (30, 60),
c = :cividis,
xlabel = "x1",
ylabel = "x2",
zlabel = "mean of y" * string(y_i),
zguidefontrotation = 90,
)
end
var_grid = reshape(rf_var[y_i, :], n_pts, n_pts)
if plot_flag
p6 = plot(
x1,
x2,
var_grid,
st = :surface,
camera = (30, 60),
c = :cividis,
xlabel = "x1",
ylabel = "x2",
zlabel = "var of y" * string(y_i),
zguidefontrotation = 90,
)

plot(p5, p6, layout = (1, 2), legend = false)

savefig(joinpath(output_directory, "RF_" * case * "_y" * string(y_i) * "_predictions.png"))
end
end

# Plot the true components of G(x1, x2)
g1_true = sin.(inputs[1, :]) .+ cos.(inputs[2, :])
g1_true_grid = reshape(g1_true, n_pts, n_pts)
if plot_flag
p7 = plot(
x1,
x2,
g1_true_grid,
st = :surface,
camera = (30, 60),
c = :cividis,
xlabel = "x1",
ylabel = "x2",
zlabel = "sin(x1) + cos(x2)",
zguidefontrotation = 90,
)
savefig(joinpath(output_directory, "RF_" * case * "_true_g1.png"))
end

g2_true = sin.(inputs[1, :]) .- cos.(inputs[2, :])
g2_true_grid = reshape(g2_true, n_pts, n_pts)
if plot_flag
p8 = plot(
x1,
x2,
g2_true_grid,
st = :surface,
camera = (30, 60),
c = :cividis,
xlabel = "x1",
ylabel = "x2",
zlabel = "sin(x1) - cos(x2)",
zguidefontrotation = 90,
)
g_true_grids = [g1_true_grid, g2_true_grid]

savefig(joinpath(output_directory, "RF_" * case * "_true_g2.png"))

end

# Plot the difference between the truth and the mean of the predictions
for y_i in 1:d

# Extract the predictive variances (diagonals of rf_cov) as output_dim x N_samples
rf_var_temp = [diag(rf_cov[j]) for j in 1:length(rf_cov)] # (40000,)
rf_var = permutedims(vcat([x' for x in rf_var_temp]...), (2, 1)) # 2 x 40000

mean_grid = reshape(rf_mean[y_i, :], n_pts, n_pts)
var_grid = reshape(rf_var[y_i, :], n_pts, n_pts)
# Compute and plot sqrt((truth - prediction)^2 / variance), i.e. the error in units of predictive standard deviation

if plot_flag
zlabel = "1/var * (true_y" * string(y_i) * " - predicted_y" * string(y_i) * ")^2"

p9 = plot(
x1,
x2,
sqrt.(1.0 ./ var_grid .* (g_true_grids[y_i] .- mean_grid) .^ 2),
st = :surface,
camera = (30, 60),
c = :magma,
zlabel = zlabel,
xlabel = "x1",
ylabel = "x2",
zguidefontrotation = 90,
)

savefig(joinpath(output_directory, "RF_" * case * "_y" * string(y_i) * "_difference_truth_prediction.png"))
end
end