Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Yaml in #128

Merged
merged 2 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ ShiftedArrays = "1277b4bf-5013-50f5-be3d-901d8477a67a"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

Expand Down Expand Up @@ -66,6 +67,7 @@ SciMLSensitivity = "7"
ShiftedArrays = "2"
Statistics = "1.10"
Tullio = "0.3"
YAML = "0.4.12"
Zygote = "0.6"
cuDNN = "1"
julia = "1"
Expand Down
59 changes: 17 additions & 42 deletions simulations/Benchmark/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ using ParameterSchedulers
using Random
using SparseArrays

########################################################################## #src
# Read the configuration file
conf = read_config("test_conf.yaml")
########################################################################## #src

# ## Random number seeds
Expand All @@ -86,19 +89,14 @@ using SparseArrays
#
# We define all the seeds here.

seeds = (;
dns = 123, # Initial conditions
θ_start = 234, # Initial CNN parameters
prior = 345, # A-priori training batch selection
post = 456, # A-posteriori training batch selection
)
seeds = load_seeds(conf)

########################################################################## #src

# ## Hardware selection

# Precision
T = Float32
T = eval(Meta.parse(conf["T"]))

# Device
if CUDA.functional()
Expand All @@ -118,45 +116,31 @@ else
clean() = nothing
end

#add backend to conf
conf["params"]["backend"] = backend

########################################################################## #src

# ## Data generation
#
# Create filtered DNS data for training, validation, and testing.

# Parameters
params = (;
D = 2,
lims = (T(0), T(1)),
Re = T(6e3),
tburn = T(0.5),
tsim = T(5),
savefreq = 10,
ndns = 2048,
nles = [128,],
filters = (FaceAverage(),),
backend,
icfunc = (setup, psolver, rng) -> random_field(setup, T(0); kp = 20, psolver, rng),
method = RKMethods.Wray3(; T),
bodyforce = (dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y),
issteadybodyforce = true,
processors = (; log = timelogger(; nupdate = 100)),
Δt = T(1e-3),
)
params = load_params(conf)

# DNS seeds
ntrajectory = 8
ntrajectory = conf["ntrajectory"]
dns_seeds = splitseed(seeds.dns, ntrajectory)
dns_seeds_train = dns_seeds[1:ntrajectory-2]
dns_seeds_valid = dns_seeds[ntrajectory-1:ntrajectory-1]
dns_seeds_test = dns_seeds[ntrajectory:ntrajectory]

# Create data
docreatedata = false
docreatedata = conf["docreatedata"]
docreatedata && createdata(; params, seeds = dns_seeds, outdir, taskid)

# Computational time
docomp = true
docomp = conf["docomp"]
docomp && let
comptime, datasize = 0.0, 0.0
for seed in dns_seeds
Expand Down Expand Up @@ -184,16 +168,7 @@ setups = map(nles -> getsetup(; params, nles), params.nles);
# All training sessions will start from the same θ₀
# for a fair comparison.

closure, θ_start, st = CoupledNODE.cnn(;
T = T,
D = params.D,
data_ch = params.D,
radii = [2, 2, 2, 2,2],
channels = [24,24,24,24, 2],
activations = [tanh,tanh,tanh,tanh, identity],
use_bias = [true,true, true,true, false],
rng = Xoshiro(seeds.θ_start),
)
closure, θ_start, st = load_model(conf)
# same model structure in INS format
closure_INS, θ_INS = cnn(;
setup = setups[1],
Expand Down Expand Up @@ -235,8 +210,8 @@ end

# Train
let
dotrain = true
nepoch = 200
dotrain = conf["priori"]["dotrain"]
nepoch = conf["priori"]["nepoch"]
dotrain && trainprior(;
params,
priorseed = seeds.prior,
Expand All @@ -248,8 +223,8 @@ let
closure,
θ_start,
st,
opt = ClipAdam = OptimiserChain(Adam(T(1.0e-2)), ClipGrad(1)),
batchsize = 64,
opt = eval(Meta.parse(conf["priori"]["opt"])),
batchsize = conf["priori"]["batchsize"],
nepoch,
)
end
Expand Down
38 changes: 38 additions & 0 deletions simulations/Benchmark/conf.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
docreatedata: true
docomp: true
ntrajectory: 8
T: "Float32"
params:
D: 2
lims: [0.0, 1.0]
Re: 6000.0
tburn: 0.5
tsim: 5.0
savefreq: 10
ndns: 2048
nles: [128]
filters: ["FaceAverage()"]
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
method: "RKMethods.Wray3(; T)"
bodyforce: "(dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y)"
issteadybodyforce: true
processors: "(; log = timelogger(; nupdate=100))"
Δt: 0.001
seeds:
dns: 123
θ_start: 234
prior: 345
post: 456
closure:
name: "CNN0"
type: cnn
radii: [2, 2, 2, 2, 2]
channels: [24, 24, 24, 24, 2]
activations: ["tanh", "tanh", "tanh", "tanh", "identity"]
use_bias: [true, true, true, true, false]
rng: "Xoshiro(seeds.θ_start)"
priori:
dotrain: true
nepoch: 100
batchsize: 32
opt: "OptimiserChain(Adam(T(1.0e-2)), ClipGrad(1))"
1 change: 1 addition & 0 deletions src/CoupledNODE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ include("loss/loss_posteriori.jl")

include("equations/NavierStokes_utils.jl")

include("io.jl")
include("utils.jl")
include("train.jl")

Expand Down
96 changes: 96 additions & 0 deletions src/io.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
using IncompressibleNavierStokes
using NeuralClosure
using Random
using YAML

function read_config(filename)
conf = YAML.load_file(filename)
return conf
end

function load_params(conf)
data = conf["params"]
T = eval(Meta.parse(conf["T"]))
function eval_field(field, T)
if field isa String
field = "T=$T; $field"
return eval(Meta.parse(field))
else
return field
end
end

params = (;
D = data["D"],
lims = (T(data["lims"][1]), T(data["lims"][2])),
Re = T(data["Re"]),
tburn = T(data["tburn"]),
tsim = T(data["tsim"]),
savefreq = data["savefreq"],
ndns = data["ndns"],
nles = data["nles"],
filters = tuple(map(f -> eval_field(f, T), data["filters"])...),
backend = eval_field(data["backend"], T),
icfunc = eval_field(data["icfunc"], T),
method = eval_field(data["method"], T),
bodyforce = eval_field(data["bodyforce"], T),
processors = eval_field(data["processors"], T),
issteadybodyforce = data["issteadybodyforce"],
Δt = T(data["Δt"])
)

return params
end

function load_seeds(conf)
data = conf["seeds"]
seeds = (;
dns = data["dns"],
θ_start = data["θ_start"],
prior = data["prior"],
post = data["post"]
)
return seeds
end

function load_model(conf)
model_type = conf["closure"]["type"]
if model_type == "cnn"
return load_cnn_params(conf)
else
error("Model type not supported")
end
end

function load_cnn_params(conf)
T = eval(Meta.parse(conf["T"]))
D = conf["params"]["D"]

# Evaluate activations and rng
function eval_field(field, s = nothing)
if field isa String
if s != nothing
field = "seeds=$s; $field"
end
return eval(Meta.parse(field))
else
return field
end
end

# Construct the cnn call
data = conf["closure"]
seeds = load_seeds(conf)
closure, θ_start, st = CoupledNODE.cnn(
T = T,
D = D,
data_ch = D,
radii = data["radii"],
channels = data["channels"],
activations = map(eval_field, data["activations"]),
use_bias = data["use_bias"],
rng = eval_field(data["rng"], seeds)
)

return closure, θ_start, st
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ The file will be automatically included inside a `@testset` with title "Title Fo
=#
for (root, dirs, files) in walkdir(@__DIR__)
for file in files
if isnothing(match(r"^test_.*\.jl$", file))
if isnothing(match(r"^test_*\.jl$", file))
continue
end
title = titlecase(replace(splitext(file[6:end])[1], "-" => " "))
Expand Down
38 changes: 38 additions & 0 deletions test/test_conf.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
docreatedata: true
docomp: true
ntrajectory: 8
T: "Float32"
params:
D: 2
lims: [0.0, 1.0]
Re: 6000.0
tburn: 0.5
tsim: 5.0
savefreq: 10
ndns: 2048
nles: [128]
filters: ["FaceAverage()"]
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
method: "RKMethods.Wray3(; T)"
bodyforce: "(dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y)"
issteadybodyforce: true
processors: "(; log = timelogger(; nupdate=100))"
Δt: 0.001
seeds:
dns: 123
θ_start: 234
prior: 345
post: 456
closure:
name: "CNN0"
type: cnn
radii: [2, 2, 2, 2, 2]
channels: [24, 24, 24, 24, 2]
activations: ["tanh", "tanh", "tanh", "tanh", "identity"]
use_bias: [true, true, true, true, false]
rng: "Xoshiro(seeds.θ_start)"
priori:
dotrain: true
nepoch: 100
batchsize: 32
opt: "OptimiserChain(Adam(T(1.0e-2)), ClipGrad(1))"
Loading
Loading