diff --git a/Project.toml b/Project.toml index 91690e9..bf71bcb 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,6 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" Glob = "c27321d9-0574-5035-807b-f59d2c89b15c" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" -Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" @@ -38,12 +37,13 @@ LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" +Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d" [sources] NeuralClosure = {rev = "main", url = "https://github.com/DEEPDIP-project/NeuralClosure.jl.git"} [extensions] -CoupledNODECUDA = ["CUDA", "cuDNN", "LuxCUDA"] +CoupledNODECUDA = ["CUDA"] NavierStokes = ["IncompressibleNavierStokes", "NeuralClosure"] [compat] diff --git a/ext/CoupledNODECUDA.jl b/ext/CoupledNODECUDA.jl index 46bc69d..0fcff7a 100644 --- a/ext/CoupledNODECUDA.jl +++ b/ext/CoupledNODECUDA.jl @@ -1,6 +1,11 @@ module CoupledNODECUDA +using CoupledNODE using CUDA: CUDA -ArrayType = CUDA.functional() ? CUDA.CuArray : Array +function ArrayType() + return CUDA.functional() ? 
CUDA.CuArray : Array +end + +allowscalar = deepcopy(CUDA.allowscalar) end diff --git a/ext/NavierStokes/callback.jl b/ext/NavierStokes/callback.jl index d9e4460..4015f06 100644 --- a/ext/NavierStokes/callback.jl +++ b/ext/NavierStokes/callback.jl @@ -38,11 +38,21 @@ The callback function is used during training to compute and log validation and """ function create_callback( model, θ, val_io_data, loss_function, st; - callbackstate = (; - θmin = θ, loss_min = eltype(θ)(Inf), lhist_val = [], - lhist_train = [], lhist_nomodel = []), - nunroll = nothing, batch_size = nothing, rng = Random.Xoshiro(123), do_plot = true, - plot_train = true, plot_every = 10, average_window = 25, device = identity) + callbackstate = nothing, + nunroll = nothing, + batch_size = nothing, + rng = Random.Xoshiro(123), + do_plot = true, + plot_train = true, + plot_every = 10, + average_window = 25, + device = identity, + figfile = nothing) + if callbackstate === nothing + # Initialize the callback state + callbackstate = (; θmin = θ, loss_min = eltype(θ)(Inf), lhist_val = [], + lhist_train = [], lhist_nomodel = []) + end if nunroll === nothing && batch_size === nothing error("Either nunroll or batch_size must be provided") elseif nunroll !== nothing @@ -95,6 +105,10 @@ function create_callback( CairoMakie.axislegend(ax) display(fig) + + if figfile !== nothing + CairoMakie.save(figfile, fig) + end end end callbackstate diff --git a/simulations/Benchmark/.gitignore b/simulations/Benchmark/.gitignore new file mode 100644 index 0000000..f9412f5 --- /dev/null +++ b/simulations/Benchmark/.gitignore @@ -0,0 +1,2 @@ +slurm* +update_julia.out diff --git a/simulations/Benchmark/benchmark.jl b/simulations/Benchmark/benchmark.jl index ec011e2..89bc61b 100644 --- a/simulations/Benchmark/benchmark.jl +++ b/simulations/Benchmark/benchmark.jl @@ -4,14 +4,35 @@ if false #src end #src @info "Script started" +@info VERSION + +using Pkg +@info Pkg.status() # Color palette for consistent theme throughout paper 
palette = (; color = ["#3366cc", "#cc0000", "#669900", "#ff9900"]) +########################################################################## #src +# Read the configuration file +using IncompressibleNavierStokes +using NeuralClosure +using CoupledNODE +NS = Base.get_extension(CoupledNODE, :NavierStokes) +global conf +try + conf = NS.read_config(ENV["CONF_FILE"]) + @info "Reading configuration file from ENV" +catch + @info "Reading configuration file from default" + conf = NS.read_config("configs/conf.yaml") +end +########################################################################## #src + # Choose where to put output basedir = haskey(ENV, "DEEPDIP") ? ENV["DEEPDIP"] : @__DIR__ outdir = joinpath(basedir, "output", "kolmogorov") closure_name = conf["closure"]["name"] +outdir_model = joinpath(outdir, closure_name) plotdir = joinpath(outdir, closure_name, "plots") logdir = joinpath(outdir, closure_name, "logs") ispath(outdir) || mkpath(outdir) @@ -51,31 +72,22 @@ setsnelliuslogger(logfile) using Accessors using Adapt -# using GLMakie using CairoMakie -using CoupledNODE using CoupledNODE: loss_priori_lux, create_loss_post_lux -using CoupledNODE.NavierStokes: create_right_hand_side, create_right_hand_side_with_closure using CUDA using DifferentialEquations -using IncompressibleNavierStokes using IncompressibleNavierStokes.RKMethods using JLD2 using LaTeXStrings using LinearAlgebra using Lux using LuxCUDA -using NeuralClosure using NNlib using Optimisers using ParameterSchedulers using Random using SparseArrays -########################################################################## #src -# Read the configuration file -conf = read_config("test_conf.yaml") -########################################################################## #src # ## Random number seeds # @@ -90,7 +102,7 @@ conf = read_config("test_conf.yaml") # # We define all the seeds here. 
-seeds = load_seeds(conf) +seeds = NS.load_seeds(conf) ########################################################################## #src @@ -116,9 +128,10 @@ else device = identity clean() = nothing end +conf["params"]["backend"] = deepcopy(backend) +@info backend +@info CUDA.versioninfo() -#add backend to conf -conf["params"]["backend"] = backend ########################################################################## #src @@ -127,7 +140,8 @@ conf["params"]["backend"] = backend # Create filtered DNS data for training, validation, and testing. # Parameters -params = load_params(conf) +params = NS.load_params(conf) +@info params # DNS seeds ntrajectory = conf["ntrajectory"] @@ -139,6 +153,7 @@ dns_seeds_test = dns_seeds[ntrajectory:ntrajectory] # Create data docreatedata = conf["docreatedata"] docreatedata && createdata(; params, seeds = dns_seeds, outdir, taskid) +@info "Data generated" # Computational time docomp = conf["docomp"] @@ -169,7 +184,7 @@ setups = map(nles -> getsetup(; params, nles), params.nles); # All training sessions will start from the same θ₀ # for a fair comparison. -closure, θ_start, st = load_model(conf) +closure, θ_start, st = NS.load_model(conf) # same model structure in INS format closure_INS, θ_INS = cnn(; setup = setups[1], @@ -179,7 +194,7 @@ closure_INS, θ_INS = cnn(; use_bias = [true,true, true,true, false], rng = Xoshiro(seeds.θ_start), ) -#@assert θ_start == θ_INS +@assert θ_start == θ_INS @info "Initialized CNN with $(length(θ_start)) parameters" @@ -207,8 +222,6 @@ end # Save parameters to disk after each run. # Plot training progress (for a validation data batch). -# Parameter save files - # Train let dotrain = conf["priori"]["dotrain"] @@ -288,18 +301,18 @@ end # Save parameters to disk after each combination. # Plot training progress (for a validation data batch). # -# The time stepper `RKProject` allows for choosing when to project. +# [INS] The time stepper `RKProject` allows for choosing when to project. 
+# [CNODE] Only DCF (last) is supported since it appears to be the best one. -# First = DIF (Bad!) -# Last = DCF -projectorders = (ProjectOrder.Last, ) -# I think that in practice we can only do DCF +projectorders = eval(Meta.parse(conf["posteriori"]["projectorders"])) nprojectorders = length(projectorders) +@assert nprojectorders == 1 "Only DCF should be done" # Train let - dotrain = true - nepoch = 100 + dotrain = conf["posteriori"]["dotrain"] + nepoch = conf["posteriori"]["nepoch"] + nepoch = 40 dotrain && trainpost(; params, projectorders, @@ -309,14 +322,14 @@ let postseed = seeds.post, dns_seeds_train, dns_seeds_valid, - nunroll = 5, + nunroll = conf["posteriori"]["nunroll"], closure, θ_start = θ_cnn_prior, st, - opt = ClipAdam = OptimiserChain(Adam(T(1.0e-3)), ClipGrad(1)), - nunroll_valid = 10, + opt = eval(Meta.parse(conf["posteriori"]["opt"])), + nunroll_valid = conf["posteriori"]["nunroll_valid"], nepoch, - dt = T(1e-3), + dt = eval(Meta.parse(conf["posteriori"]["dt"])), ) end @@ -404,11 +417,11 @@ let eprior.post[ig, ifil, iorder] = priori_err(device(θ_cnn_post[ig, ifil, iorder]))[1] end end - jldsave(joinpath(outdir, "eprior.jld2"); eprior...) + jldsave(joinpath(outdir_model, "eprior.jld2"); eprior...) 
end clean() -eprior = namedtupleload(joinpath(outdir, "eprior.jld2")) +eprior = namedtupleload(joinpath(outdir_model, "eprior.jld2")) ########################################################################## #src @@ -445,26 +458,23 @@ let dt = T(1e-3) ## No model - dudt_nomod = create_right_hand_side( + dudt_nomod = NS.create_right_hand_side( setup, psolver) err_post = create_loss_post_lux(dudt_nomod; sciml_solver = Tsit5(), dt = dt) epost.nomodel[I] = err_post(closure, θ_cnn_post[I].*0 , st, data)[1] # with closure - dudt = create_right_hand_side_with_closure( + dudt = NS.create_right_hand_side_with_closure( setup, psolver, closure, st) err_post = create_loss_post_lux(dudt; sciml_solver = Tsit5(), dt = dt) epost.cnn_prior[I] = err_post(closure, device(θ_cnn_prior[ig, ifil]), st, data)[1] epost.cnn_post[I] = err_post(closure, device(θ_cnn_post[I]), st, data)[1] clean() end - jldsave(joinpath(outdir, "epost.jld2"); epost...) + jldsave(joinpath(outdir_model, "epost.jld2"); epost...) end -epost = namedtupleload(joinpath(outdir, "epost.jld2")) +epost = namedtupleload(joinpath(outdir_model, "epost.jld2")) -epost.nomodel -epost.cnn_prior -epost.cnn_post ########################################################################## #src diff --git a/simulations/Benchmark/conf.yaml b/simulations/Benchmark/configs/conf.yaml similarity index 74% rename from simulations/Benchmark/conf.yaml rename to simulations/Benchmark/configs/conf.yaml index c11e726..4cb82cb 100644 --- a/simulations/Benchmark/conf.yaml +++ b/simulations/Benchmark/configs/conf.yaml @@ -24,7 +24,7 @@ seeds: prior: 345 post: 456 closure: - name: "CNN0" + name: "cnn_0" type: cnn radii: [2, 2, 2, 2, 2] channels: [24, 24, 24, 24, 2] @@ -33,6 +33,14 @@ closure: rng: "Xoshiro(seeds.θ_start)" priori: dotrain: true - nepoch: 100 + nepoch: 500 batchsize: 32 - opt: "OptimiserChain(Adam(T(1.0e-2)), ClipGrad(1))" \ No newline at end of file + opt: "OptimiserChain(Adam(T(1.0e-2)), ClipGrad(1))" +posteriori: + dotrain: true 
+ projectorders: "(ProjectOrder.Last, )" + nepoch: 200 + opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(1))" + nunroll: 5 + nunroll_valid: 10 + dt: T(1e-3) diff --git a/simulations/Benchmark/configs/conf_2.yaml b/simulations/Benchmark/configs/conf_2.yaml new file mode 100644 index 0000000..19dccc4 --- /dev/null +++ b/simulations/Benchmark/configs/conf_2.yaml @@ -0,0 +1,48 @@ +docreatedata: false +docomp: true +ntrajectory: 8 +T: "Float32" +params: + D: 2 + lims: [0.0, 1.0] + Re: 6000.0 + tburn: 0.5 + tsim: 5.0 + savefreq: 10 + #ndns: 2048 + #nles: [128] + ndns: 256 + nles: [64] + filters: ["FaceAverage()"] + icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)" + method: "RKMethods.Wray3(; T)" + bodyforce: "(dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y)" + issteadybodyforce: true + processors: "(; log = timelogger(; nupdate=100))" + Δt: 0.001 +seeds: + dns: 123 + θ_start: 234 + prior: 345 + post: 456 +closure: + name: "cnn_1" + type: cnn + radii: [2, 2, 2, 2, 2] + channels: [24, 24, 24, 24, 2] + activations: ["relu", "relu", "relu", "relu", "identity"] + use_bias: [true, true, true, true, false] + rng: "Xoshiro(seeds.θ_start)" +priori: + dotrain: true + nepoch: 500 + batchsize: 32 + opt: "OptimiserChain(Adam(T(1.0e-2)), ClipGrad(1))" +posteriori: + dotrain: true + projectorders: "(ProjectOrder.Last, )" + nepoch: 200 + opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(1))" + nunroll: 5 + nunroll_valid: 10 + dt: T(1e-3) diff --git a/simulations/Benchmark/job_a100.sh b/simulations/Benchmark/job_a100.sh index 4aa8055..51b968c 100644 --- a/simulations/Benchmark/job_a100.sh +++ b/simulations/Benchmark/job_a100.sh @@ -7,8 +7,9 @@ #SBATCH --partition=gpu_a100 #SBATCH --time=05:00:00 #SBATCH --mail-type=BEGIN,END -#SBATCH --mail-user=s.ciarella@esciencecenter.nl -#SBATCH --array=1-8 +# #SBATCH --mail-user=s.ciarella@esciencecenter.nl +#SBATCH --array=1-1 +# #SBATCH --array=1-8 # Note: # - gpu_a100: 18 cores @@ -20,7 +21,8 @@ mkdir -p 
/scratch-shared/$USER echo "Slurm job ID: $SLURM_JOB_ID" echo "Slurm array task ID: $SLURM_ARRAY_TASK_ID" -export JULIA_DEPOT_PATH=/scratch-shared/$USER/.julia_a100: +export JULIA_DEPOT_PATH=/scratch-shared/$USER/.julia_a100 +export CONF_FILE=$1 cd $HOME/CoupledNODE.jl/simulations/Benchmark diff --git a/simulations/Benchmark/src/Benchmark.jl b/simulations/Benchmark/src/Benchmark.jl index 3e0f786..7e4a1b7 100644 --- a/simulations/Benchmark/src/Benchmark.jl +++ b/simulations/Benchmark/src/Benchmark.jl @@ -8,7 +8,7 @@ using Adapt using ComponentArrays using CoupledNODE using CoupledNODE: loss_priori_lux, create_loss_post_lux -using CoupledNODE.NavierStokes: create_right_hand_side_with_closure +using CUDA using Dates using DifferentialEquations using DocStringExtensions diff --git a/simulations/Benchmark/src/train.jl b/simulations/Benchmark/src/train.jl index c3931f5..cf18c2e 100644 --- a/simulations/Benchmark/src/train.jl +++ b/simulations/Benchmark/src/train.jl @@ -20,6 +20,10 @@ createdata(; params, seeds, outdir, taskid) = ispath(datadir) || mkpath(datadir) push!(filenames, f) end + if isfile(filenames[1]) + @info "Data file $(filenames[1]) already exists. Skipping." 
+ continue + end data = create_les_data(; params..., rng = Xoshiro(seed), filenames, Δt = params.Δt) @info("Trajectory info:", data[1].comptime/60, @@ -27,7 +31,7 @@ createdata(; params, seeds, outdir, taskid) = Base.summarysize(data)*1e-9,) end -function getpriorfile(outdir, nles, filter) +function getpriorfile(outdir, closure_name, nles, filter) joinpath( outdir, "priortraining", closure_name, splatfileparts(; filter, nles) * ".jld2") end @@ -53,7 +57,7 @@ function trainprior(; st, opt, batchsize, - loadcheckpoint = false, + loadcheckpoint = true, nepoch ) device(x) = adapt(params.backend, x) @@ -96,25 +100,40 @@ function trainprior(; data_i = namedtupleload(getdatafile(outdir, nles, Φ, s)) push!(data_valid, hcat(data_i)) end - io_train = CoupledNODE.NavierStokes.create_io_arrays_priori(data_train, setup) - io_valid = CoupledNODE.NavierStokes.create_io_arrays_priori(data_valid, setup) + NS = Base.get_extension(CoupledNODE, :NavierStokes) + io_train = NS.create_io_arrays_priori(data_train, setup) + io_valid = NS.create_io_arrays_priori(data_valid, setup) θ = device(copy(θ_start)) - dataloader_prior = CoupledNODE.NavierStokes.create_dataloader_prior( + dataloader_prior = NS.create_dataloader_prior( io_train[itotal]; batchsize = batchsize, - rng = Random.Xoshiro(dns_seeds_train[itotal])) + rng = Random.Xoshiro(dns_seeds_train[itotal]), device = device) train_data_priori = dataloader_prior() loss_priori_lux(closure, θ, st, train_data_priori) loss = loss_priori_lux - callbackstate, callback = CoupledNODE.create_callback( - closure, θ, io_valid[itotal], loss, st, batch_size = batchsize, - rng = Xoshiro(batchseed), do_plot = true, plot_train = true) + if loadcheckpoint && isfile(checkfile) + callbackstate, trainstate, epochs_trained = CoupledNODE.load_checkpoint(checkfile) + nepochs_left = nepoch - epochs_trained + else + callbackstate = trainstate = nothing + nepochs_left = nepoch + end - l, trainstate = CoupledNODE.train( - closure, θ, st, dataloader_prior, loss; 
nepochs = nepoch, - alg = opt, cpu = params.backend == CPU(), callback = callback) - # TODO CoupledNODE has no checkpoints yet, but here it should save them - # TODO CoupledNODE should also save some figures + callbackstate, callback = NS.create_callback( + closure, θ, io_valid[itotal], loss, st; + callbackstate = callbackstate, batch_size = batchsize, + rng = Xoshiro(batchseed), do_plot = true, plot_train = true, figfile = figfile, device = device) + + if nepochs_left <= 0 + @info "No epochs left to train." + continue + else + l, trainstate = CoupledNODE.train( + closure, θ, st, dataloader_prior, loss; tstate = trainstate, + nepochs = nepochs_left, + alg = opt, cpu = params.backend == CPU(), callback = callback) + end + save_object(checkfile, (callbackstate = callbackstate, trainstate = trainstate)) θ = callbackstate.θmin # Use best θ instead of last θ results = (; θ = Array(θ), comptime = time() - starttime, @@ -147,13 +166,14 @@ function trainpost(; nunroll, closure, θ_start, + loadcheckpoint = true, st, opt, nunroll_valid, nepoch, dt ) - device(x) = adapt(params.backend, x) + device(x) = CUDA.functional() ? 
adapt(params.backend, x) : x itotal = 0 for projectorder in projectorders, (ifil, Φ) in enumerate(params.filters), @@ -195,28 +215,40 @@ function trainpost(; data_i = namedtupleload(getdatafile(outdir, nles, Φ, s)) push!(data_valid, hcat(data_i)) end - io_train = CoupledNODE.NavierStokes.create_io_arrays_posteriori(data_train, setup) - io_valid = CoupledNODE.NavierStokes.create_io_arrays_posteriori(data_valid, setup) - #θ = copy(θ_start) + NS = Base.get_extension(CoupledNODE, :NavierStokes) + io_train = NS.create_io_arrays_posteriori(data_train, setup) + io_valid = NS.create_io_arrays_posteriori(data_valid, setup) θ = device(copy(θ_start[itotal])) - dataloader_post = CoupledNODE.NavierStokes.create_dataloader_posteriori( + dataloader_post = NS.create_dataloader_posteriori( io_train[itotal]; nunroll = nunroll, - rng = Random.Xoshiro(dns_seeds_train[itotal])) + rng = Random.Xoshiro(dns_seeds_train[itotal]), device = device) - dudt_nn = create_right_hand_side_with_closure( + dudt_nn = NS.create_right_hand_side_with_closure( setup[1], psolver, closure, st) loss = create_loss_post_lux(dudt_nn; sciml_solver = Tsit5(), dt = dt) - callbackstate, callback = CoupledNODE.create_callback( - closure, θ, io_valid[itotal], loss, st, nunroll = nunroll_valid, - rng = Xoshiro(postseed), do_plot = true, plot_train = true) + if loadcheckpoint && isfile(checkfile) + callbackstate, trainstate, epochs_trained = CoupledNODE.load_checkpoint(checkfile) + nepochs_left = nepoch - epochs_trained + else + callbackstate = trainstate = nothing + nepochs_left = nepoch + end - l, trainstate = CoupledNODE.train( - closure, θ, st, dataloader_post, loss; nepochs = nepoch, - alg = opt, cpu = params.backend == CPU(), callback = callback) - # TODO CoupledNODE has no checkpoints yet, but here it should save them - # TODO CoupledNODE should also save some figures + callbackstate, callback = NS.create_callback( + closure, θ, io_valid[itotal], loss, st; + callbackstate = callbackstate, nunroll = 
nunroll_valid, + rng = Xoshiro(postseed), do_plot = true, plot_train = true, figfile = figfile, device = device) + if nepochs_left <= 0 + @info "No epochs left to train." + continue + else + l, trainstate = CoupledNODE.train( + closure, θ, st, dataloader_post, loss; tstate = trainstate, nepochs = nepochs_left, + alg = opt, cpu = params.backend == CPU(), callback = callback) + end + save_object(checkfile, (callbackstate = callbackstate, trainstate = trainstate)) θ = callbackstate.θmin # Use best θ instead of last θ results = (; θ = Array(θ), comptime = time() - starttime, diff --git a/src/CoupledNODE.jl b/src/CoupledNODE.jl index 30447ea..ec87d42 100644 --- a/src/CoupledNODE.jl +++ b/src/CoupledNODE.jl @@ -7,6 +7,7 @@ include("models/cnn.jl") include("loss/loss_priori.jl") include("loss/loss_posteriori.jl") +include("checkpoints.jl") include("train.jl") end # module CoupledNODE diff --git a/src/checkpoints.jl b/src/checkpoints.jl new file mode 100644 index 0000000..5d4493d --- /dev/null +++ b/src/checkpoints.jl @@ -0,0 +1,28 @@ +using JLD2 + +""" + load_checkpoint(checkfile) + +Load a training checkpoint from the specified file. + +# Arguments +- `checkfile::String`: The path to the checkpoint file. + +# Returns +- `callbackstate`: The state of the callback at the checkpoint. +- `trainstate`: The state of the training process at the checkpoint. +- `epochs_trained::Int`: The number of epochs completed at the checkpoint. + +# Example +```julia +callbackstate, trainstate, epochs_trained = load_checkpoint("checkpoint.jld2") +``` +""" +function load_checkpoint(checkfile) + checkpoint = load_object(checkfile) + callbackstate = checkpoint.callbackstate + trainstate = checkpoint.trainstate + epochs_trained = length(callbackstate.lhist_train) + @info "Loading checkpoint from $checkfile.\nPrevious training reached epoch $(epochs_trained)." 
+ return callbackstate, trainstate, epochs_trained +end diff --git a/src/loss/loss_posteriori.jl b/src/loss/loss_posteriori.jl index e267949..b0614af 100644 --- a/src/loss/loss_posteriori.jl +++ b/src/loss/loss_posteriori.jl @@ -157,33 +157,25 @@ normalized by the sum of squared actual data values. This makes it compatible with the Lux ecosystem. """ function create_loss_post_lux(rhs; sciml_solver = Tsit5(), cpu::Bool = false, kwargs...) - dev = cpu ? Lux.cpu_device() : Lux.gpu_device() - ext = Base.get_extension(@__MODULE__, :CoupledNODECUDA) - if !isnothing(ext) - ArrayType = cpu ? Array : CUDA.CuArray + Cuda_ext = Base.get_extension(CoupledNODE, :CoupledNODECUDA) + if !isnothing(Cuda_ext) + ArrayType = cpu ? Array : Cuda_ext.ArrayType() + dev = cpu ? Lux.cpu_device() : Lux.gpu_device() else ArrayType = Array + dev = Lux.cpu_device() end function loss_function(model, ps, st, (u, t)) griddims = Zygote.@ignore ((:) for _ in 1:(ndims(u) - 2)) x = dev(u[griddims..., :, 1]) y = dev(u[griddims..., :, 2:end]) # remember to discard sol at the initial time step tspan, dt, prob, pred = nothing, nothing, nothing, nothing # initialize variable outside allowscalar do. 
- if !isnothing(ext) - CUDA.allowscalar() do - if !(:dt in keys(kwargs)) - dt = t[2] - t[1] - kwargs = (; kwargs..., dt = dt) - end - tspan = [t[1], t[end]] - end - else - if !(:dt in keys(kwargs)) - dt = t[2] - t[1] - kwargs = (; kwargs..., dt = dt) - end - tspan = [t[1], t[end]] + if !(:dt in keys(kwargs)) + dt = @views t[2:2] .- t[1:1] + dt = only(Array(dt)) + kwargs = (; kwargs..., dt = dt) end + tspan = @views [t[1:1]; t[end:end]] prob = ODEProblem(rhs, x, tspan, ps) pred = ArrayType(solve( prob, sciml_solver; u0 = x, p = ps, adaptive = false, saveat = t, kwargs...)) diff --git a/src/train.jl b/src/train.jl index 1d99afb..0d00949 100644 --- a/src/train.jl +++ b/src/train.jl @@ -1,20 +1,28 @@ using SciMLSensitivity using Lux: Lux -using Juno: Juno using Zygote: Zygote using Optimization: Optimization using OptimizationOptimisers: OptimizationOptimisers +using CairoMakie: save function train(model, ps, st, train_dataloader, loss_function; - nepochs = 100, ad_type = Optimization.AutoZygote(), - alg = OptimizationOptimisers.Adam(0.1), cpu::Bool = false, kwargs...) + nepochs = 100, + ad_type = Optimization.AutoZygote(), + alg = OptimizationOptimisers.Adam(0.1), + cpu::Bool = false, + kwargs...) dev = cpu ? 
Lux.cpu_device() : Lux.gpu_device() ps, st = (ps, st) .|> dev # Retrieve the callback from kwargs, default to `nothing` if not provided callback = get(kwargs, :callback, nothing) - tstate = Lux.Training.TrainState(model, ps, st, alg) + # Retrieve the training state from kwargs, otherwise create a new one + tstate = get(kwargs, :tstate, nothing) + if tstate === nothing + tstate = Lux.Training.TrainState(model, ps, st, alg) + end loss::Float32 = 0 #NOP TODO: check compatibiity with precision of data - Juno.@progress for epoch in 1:nepochs + @info "Lux Training started" + for epoch in 1:nepochs data = Zygote.@ignore dev(train_dataloader()) _, loss, _, tstate = Lux.Training.single_train_step!( ad_type, loss_function, data, tstate) diff --git a/test/test_CUDA.jl b/test/test_CUDA.jl new file mode 100644 index 0000000..42d59ac --- /dev/null +++ b/test/test_CUDA.jl @@ -0,0 +1,13 @@ +using Test +using CoupledNODE + +@testset "CUDA" begin + using Pkg + Pkg.add("CUDA") + using CUDA + Cuda_ext = Base.get_extension(CoupledNODE, :CoupledNODECUDA) + ArrayType = Cuda_ext.ArrayType() + @test ArrayType == CUDA.CuArray || ArrayType == Array + @test Cuda_ext.allowscalar(false) === nothing + @test Cuda_ext.allowscalar(true) === nothing +end diff --git a/test/test_io.jl b/test/test_io.jl index 8608bc8..6f420b3 100644 --- a/test/test_io.jl +++ b/test/test_io.jl @@ -1,7 +1,7 @@ using CoupledNODE -NS = Base.get_extension(CoupledNODE, :NavierStokes) using IncompressibleNavierStokes using NeuralClosure +NS = Base.get_extension(CoupledNODE, :NavierStokes) using Random @testset "Read YAML" begin @@ -46,10 +46,23 @@ using Random @test params.issteadybodyforce == ref_params.issteadybodyforce @test params.processors == ref_params.processors @test params.Δt == ref_params.Δt - # TODO: I do not know how to test those 3 - #@test params.icfunc == ref_params.icfunc + # test icfunc + setups = map(params.nles) do nles + x = ntuple(α -> LinRange(T(0.0), T(1.0), nles + 1), params.D) + Setup(; x = x, Re = 
params.Re) + end + setup = setups[1] + psolver = psolver_spectral(setup) + @test params.icfunc(setup, psolver, Xoshiro(123)) == + ref_params.icfunc(setup, psolver, Xoshiro(123)) + # test bodyforce + x = rand(1)[1] + y = rand(1)[1] + t = 0.0 + dim = 1 + @test params.bodyforce(dim, x, y, t) == ref_params.bodyforce(dim, x, y, t) + # TODO: test method #@test params.method == ref_params.method - #@test params.bodyforce == ref_params.bodyforce # test seeds ref_seeds = (;