Skip to content

Commit

Permalink
Check if data already gen
Browse files Browse the repository at this point in the history
  • Loading branch information
SCiarella committed Dec 20, 2024
1 parent 4f7f496 commit e6ca6e8
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
10 changes: 8 additions & 2 deletions simulations/Benchmark/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,14 @@ using IncompressibleNavierStokes
using NeuralClosure
using CoupledNODE
NS = Base.get_extension(CoupledNODE, :NavierStokes)
#conf = NS.read_config("configs/conf.yaml")
conf = NS.read_config(ENV["CONF_FILE"])
conf =nothing
try
conf = NS.read_config(ENV["CONF_FILE"])
@info "Reading configuration file from ENV"
catch
@info "Reading configuration file from default"
conf = NS.read_config("configs/conf.yaml")
end
########################################################################## #src

# Choose where to put output
Expand Down
8 changes: 6 additions & 2 deletions simulations/Benchmark/src/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ createdata(; params, seeds, outdir, taskid) =
ispath(datadir) || mkpath(datadir)
push!(filenames, f)
end
if isfile(filenames[1])
@info "Data file $(filenames[1]) already exists. Skipping."
continue
end
data = create_les_data(; params..., rng = Xoshiro(seed), filenames, Δt = params.Δt)
@info("Trajectory info:",
data[1].comptime/60,
Expand Down Expand Up @@ -102,7 +106,7 @@ function trainprior(;
θ = device(copy(θ_start))
dataloader_prior = NS.create_dataloader_prior(
io_train[itotal]; batchsize = batchsize,
rng = Random.Xoshiro(dns_seeds_train[itotal]))
rng = Random.Xoshiro(dns_seeds_train[itotal]), device = device)
train_data_priori = dataloader_prior()
loss_priori_lux(closure, θ, st, train_data_priori)
loss = loss_priori_lux
Expand Down Expand Up @@ -218,7 +222,7 @@ function trainpost(;
θ = device(copy(θ_start[itotal]))
dataloader_post = NS.create_dataloader_posteriori(
io_train[itotal]; nunroll = nunroll,
rng = Random.Xoshiro(dns_seeds_train[itotal]))
rng = Random.Xoshiro(dns_seeds_train[itotal]), device = device)

dudt_nn = NS.create_right_hand_side_with_closure(
setup[1], psolver, closure, st)
Expand Down

0 comments on commit e6ca6e8

Please sign in to comment.