Skip to content

Commit

Permalink
fix device
Browse files Browse the repository at this point in the history
  • Loading branch information
SCiarella committed Dec 20, 2024
1 parent e6ca6e8 commit eea18be
Show file tree
Hide file tree
Showing 7 changed files with 10 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Expand All @@ -38,6 +37,7 @@ LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"

[sources]
NeuralClosure = {rev = "main", url = "https://github.com/DEEPDIP-project/NeuralClosure.jl.git"}
Expand Down
4 changes: 1 addition & 3 deletions simulations/Benchmark/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ using IncompressibleNavierStokes
using NeuralClosure
using CoupledNODE
NS = Base.get_extension(CoupledNODE, :NavierStokes)
conf =nothing
global conf
try
conf = NS.read_config(ENV["CONF_FILE"])
@info "Reading configuration file from ENV"
Expand Down Expand Up @@ -222,8 +222,6 @@ end
# Save parameters to disk after each run.
# Plot training progress (for a validation data batch).

# Parameter save files

# Train
let
dotrain = conf["priori"]["dotrain"]
Expand Down
6 changes: 2 additions & 4 deletions simulations/Benchmark/configs/conf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@ params:
tburn: 0.5
tsim: 5.0
savefreq: 10
#ndns: 2048
#nles: [128]
ndns: 256
nles: [64]
ndns: 2048
nles: [128]
filters: ["FaceAverage()"]
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
method: "RKMethods.Wray3(; T)"
Expand Down
2 changes: 0 additions & 2 deletions simulations/Benchmark/job_a100.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
# - gpu_h100: 16 cores
# https://servicedesk.surf.nl/wiki/display/WIKI/Snellius+partitions+and+accounting

nvidia-smi

mkdir -p /scratch-shared/$USER

echo "Slurm job ID: $SLURM_JOB_ID"
Expand Down
1 change: 1 addition & 0 deletions simulations/Benchmark/src/Benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using Adapt
using ComponentArrays
using CoupledNODE
using CoupledNODE: loss_priori_lux, create_loss_post_lux
using CUDA
using Dates
using DifferentialEquations
using DocStringExtensions
Expand Down
6 changes: 3 additions & 3 deletions simulations/Benchmark/src/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ function trainprior(;
callbackstate, callback = NS.create_callback(
closure, θ, io_valid[itotal], loss, st;
callbackstate = callbackstate, batch_size = batchsize,
rng = Xoshiro(batchseed), do_plot = true, plot_train = true, figfile = figfile)
rng = Xoshiro(batchseed), do_plot = true, plot_train = true, figfile = figfile, device = device)

if nepochs_left <= 0
@info "No epochs left to train."
Expand Down Expand Up @@ -173,7 +173,7 @@ function trainpost(;
nepoch,
dt
)
device(x) = adapt(params.backend, x)
device(x) = CUDA.functional ? adapt(params.backend, x) : x
itotal = 0
for projectorder in projectorders,
(ifil, Φ) in enumerate(params.filters),
Expand Down Expand Up @@ -239,7 +239,7 @@ function trainpost(;
callbackstate, callback = NS.create_callback(
closure, θ, io_valid[itotal], loss, st;
callbackstate = callbackstate, nunroll = nunroll_valid,
rng = Xoshiro(postseed), do_plot = true, plot_train = true, figfile = figfile)
rng = Xoshiro(postseed), do_plot = true, plot_train = true, figfile = figfile, device = device)
if nepochs_left <= 0
@info "No epochs left to train."
continue
Expand Down
4 changes: 2 additions & 2 deletions src/train.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using SciMLSensitivity
using Lux: Lux
using Juno: Juno
using Zygote: Zygote
using Optimization: Optimization
using OptimizationOptimisers: OptimizationOptimisers
Expand All @@ -22,7 +21,8 @@ function train(model, ps, st, train_dataloader, loss_function;
tstate = Lux.Training.TrainState(model, ps, st, alg)
end
loss::Float32 = 0 #NOP TODO: check compatibiity with precision of data
Juno.@progress for epoch in 1:nepochs
@info "Lux Training started"
for epoch in 1:nepochs
data = Zygote.@ignore dev(train_dataloader())
_, loss, _, tstate = Lux.Training.single_train_step!(
ad_type, loss_function, data, tstate)
Expand Down

0 comments on commit eea18be

Please sign in to comment.