Skip to content

Commit

Permalink
chore: update
Browse files Browse the repository at this point in the history
  • Loading branch information
agdestein committed Sep 23, 2024
1 parent 4c37146 commit a5ac4bb
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 82 deletions.
6 changes: 2 additions & 4 deletions lib/NeuralClosure/src/training.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,9 @@ function create_loss_post(;
psolver,
closure,
nupdate = 1,
projectorder = :last,
)
closure_model = wrappedclosure(closure, setup)
setup = (; setup..., closure_model, projectorder)
setup = (; setup..., closure_model)
(; dimension, Iu) = setup.grid
D = dimension()
function loss_post(data, θ)
Expand Down Expand Up @@ -119,9 +118,8 @@ function create_relerr_post(;
psolver,
closure_model,
nupdate = 1,
projectorder = :last,
)
setup = (; setup..., closure_model, projectorder)
setup = (; setup..., closure_model)
(; dimension, Iu) = setup.grid
D = dimension()
(; u, t) = data
Expand Down
177 changes: 99 additions & 78 deletions lib/PaperDC/les3D.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,7 @@ end

# Load filtered DNS data
data = load.(filenames, "data");
@info(
"Data: ",
Base.summarysize(data) * 1e-9,
length.(getfield.(data, :t)),
)
@info("Data: ", Base.summarysize(data) * 1e-9, length.(getfield.(data, :t)),)

sum(d -> d.comptime, data) / 3600

Expand Down Expand Up @@ -240,20 +236,22 @@ end
# CNN architecture
closure, θ₀ = cnn(;
setup = setups[1],
radii = [2, 2, 2, 2, 2],
channels = [24, 24, 24, 24, params.D],
activations = [tanh, tanh, tanh, tanh, identity],
use_bias = [true, true, true, true, false],
radii = [2, 2, 2, 2],
channels = [24, 24, 24, params.D],
activations = [tanh, tanh, tanh, identity],
use_bias = [true, true, true, false],
rng = Xoshiro(seeds.θ₀),
);
closure.chain

@info "Initialized CNN with $(length(θ₀)) parameters"

# Give the CNN a test run
# Note: Data and parameters are stored on the CPU, and
# must be moved to the GPU before use (with `gpu_device`)
let
using NeuralClosure.Zygote
u = io_train[1, 1].u[:, :, :, :, 1:50] |> gpu_device()
u = io_train[1, 1].u[:, :, :, :, 1:10] |> gpu_device()
θ = θ₀ |> gpu_device()
closure(u, θ)
gradient-> sum(closure(u, θ)), θ)
Expand All @@ -278,7 +276,8 @@ priorfiles = map(CartesianIndices(io_train)) do I
end

# Train
let
trainprior = false
trainprior && let
I = CartesianIndices(io_train)
itask = parse(Int, ENV["SLURM_ARRAY_TASK_ID"])
# ig, ifil = I[itask].I
Expand All @@ -289,7 +288,7 @@ let
starttime = time()
@info "Training a-priori for ig = $ig, ifil = $ifil"
trainseed, validseed = splitseed(seeds.prior, 2) # Same seed for all training setups
dataloader = create_dataloader_prior(io_train[ig, ifil]; batchsize = 50, device)
dataloader = create_dataloader_prior(io_train[ig, ifil]; batchsize = 20, device)
θ = T(1.0) * device(θ₀)
loss = create_loss_prior(mean_squared_error, closure)
opt = Adam(T(1.0e-3))
Expand All @@ -316,8 +315,14 @@ let
@reset callbackstate.θmin = callbackstate.θmin |> gpu_device()
end
for icheck = ncheck+1:10
(; trainstate, callbackstate) =
train(; dataloader, loss, trainstate, callbackstate, callback, niter = 1_000)
(; trainstate, callbackstate) = train(;
dataloader,
loss,
trainstate,
callbackstate,
callback,
niter = 1_000,
)
# Save all states to resume training later
# First move all arrays to CPU
c = callbackstate |> cpu_device()
Expand All @@ -327,25 +332,22 @@ let
θ = callbackstate.θmin # Use best θ instead of last θ
prior = (; θ = Array(θ), comptime = time() - starttime, callbackstate.hist)
jldsave(priorfiles[ig, ifil]; prior)
# end
clean()
end

exit()

# Load learned parameters and training times
prior = load.(priorfiles, "prior")
θ_cnn_prior = [copyto!(device(θ₀), p.θ) for p in prior];

# Check that parameters are within reasonable bounds
θ_cnn_prior .|> extrema

# Training times
map(p -> p.comptime, prior)
map(p -> p.comptime, prior) |> vec
map(p -> p.comptime, prior) |> sum # Seconds
map(p -> p.comptime, prior) |> sum |> x -> x / 60 # Minutes
map(p -> p.comptime, prior) |> sum |> x -> x / 3600 # Hours
# # Load learned parameters and training times
# prior = load.(priorfiles, "prior")
# θ_cnn_prior = [copyto!(device(θ₀), p.θ) for p in prior];
#
# # Check that parameters are within reasonable bounds
# θ_cnn_prior .|> extrema
#
# # Training times
# map(p -> p.comptime, prior)
# map(p -> p.comptime, prior) |> vec
# map(p -> p.comptime, prior) |> sum # Seconds
# map(p -> p.comptime, prior) |> sum |> x -> x / 60 # Minutes
# map(p -> p.comptime, prior) |> sum |> x -> x / 3600 # Hours

########################################################################## #src

Expand All @@ -359,66 +361,85 @@ map(p -> p.comptime, prior) |> sum |> x -> x / 3600 # Hours
#
# The time stepper `RKProject` allows for choosing when to project.

I_post = CartesianIndices((length(params.filters), length(params.nles), 2))

# Parameter save files
postfiles = map(CartesianIndices((size(io_train)..., 2))) do I
postfiles = map(I_post) do I
ig, ifil, iorder = I.I
"$outdir/post_iorder$(iorder)_ifil$(ifil)_ig$(ig).jld2"
end

# Train
let
ngrid, nfilter = size(io_train)
for iorder = 1:2, ifil = 1:nfilter, ig = 1:ngrid
clean()
starttime = time()
@info "Training a-posteriori for iorder = $iorder, ifil = $ifil, ig = $ig"
projectorder = ProjectOrder.T(iorder)
rng = Xoshiro(seeds.post) # Same seed for all training setups
setup = setups[ig]
psolver = psolver_spectral(setup)
loss = create_loss_post(;
trainpost = true
trainpost && let
itask = parse(Int, ENV["SLURM_ARRAY_TASK_ID"])
# ig, ifil, iorder = I_post[itask].I
ig, ifil, iorder = 2, 2, 2
# ngrid, nfilter = size(io_train)
# for iorder = 1:2, ifil = 1:nfilter, ig = 1:ngrid
clean()
starttime = time()
@info "Training a-posteriori for iorder = $iorder, ifil = $ifil, ig = $ig"
projectorder = ProjectOrder.T(iorder)
setup = setups[ig]
psolver = psolver_spectral(setup)
loss = create_loss_post(;
setup,
psolver,
method = RKProject(params.method, projectorder),
closure,
nupdate = 2, # Time steps per loss evaluation
)
dataloader = create_dataloader_post(map(d -> (; u = d.data[ig, ifil].u, d.t) , data_train); device, nunroll = 5)
# θ = copy(θ_cnn_prior[ig, ifil])
θ = device(θ₀)
opt = Adam(T(1.0e-3))
optstate = Optimisers.setup(opt, θ)
it = 1:20
traj = data_valid[1]
traj = (; u = device.(traj.data[ig, ifil].u[it]), t = traj.t[it])
@info "Validating on times $(traj.t[it])"
(; callbackstate, callback) = create_callback(
create_relerr_post(;
data = traj,
setup,
psolver,
method = RKProject(RK44(; T), projectorder),
closure,
nupdate = 2, # Time steps per loss evaluation
)
data = [(; u = d.data[ig, ifil].u, d.t) for d in data_train]
dataloader = create_dataloader_post(data; device, nunroll = 20)
θ = copy(θ_cnn_prior[ig, ifil])
opt = Adam(T(1.0e-3))
optstate = Optimisers.setup(opt, θ)
it = 1:30
data = data_valid[1]
data = (; u = device.(data.data[ig, ifil].u[it]), t = data.t[it])
(; callbackstate, callback) = create_callback(
create_relerr_post(;
data,
setup,
psolver,
method = RKProject(RK44(; T), projectorder),
closure_model = wrappedclosure(closure, setup),
nupdate = 2,
);
θ,
displayref = false,
nupdate = 10,
)
(; trainstate, callbackstate) = train(;
dataloader,
loss,
trainstate = (; optstate, θ, rng),
niter = 2000,
callbackstate,
callback,
)
θ = callbackstate.θmin # Use best θ instead of last θ
post = (; θ = Array(θ), comptime = time() - starttime)
jldsave(postfiles[iorder, ifil, ig]; post)
method = RKProject(params.method, projectorder),
closure_model = wrappedclosure(closure, setup),
nupdate = 2,
);
θ,
displayref = false,
nupdate = 5,
)
trainstate = (; optstate, θ, rng = Xoshiro(seeds.post))
base, ext = splitext(postfiles[ig, ifil, iorder])
checkpointname = "$(base)_checkpoint.jld2"
ncheck = 0
if false
@info "Resuming from checkpoint $checkpointname"
ncheck, trainstate, callbackstate =
load(checkpointname, "ncheck", "trainstate", "callbackstate")
trainstate = trainstate |> gpu_device()
@reset callbackstate.θmin = callbackstate.θmin |> gpu_device()
end
for icheck = ncheck+1:10
(; trainstate, callbackstate) =
train(; dataloader, loss, trainstate, niter = 200, callbackstate, callback)
@info "Saving checkpoint to $(basename(checkpointname))..."
c = callbackstate |> cpu_device()
t = trainstate |> cpu_device()
jldsave(checkpointname; ncheck = icheck, callbackstate = c, trainstate = t)
@info "... done"
end
θ = callbackstate.θmin # Use best θ instead of last θ
post = (; θ = Array(θ), comptime = time() - starttime)
jldsave(postfiles[iorder, ifil, ig]; post)
clean()
end

exit()

# Load learned parameters and training times
post = load.(postfiles, "post");
θ_cnn_post = [copyto!(device(θ₀), p.θ) for p in post];
Expand Down

0 comments on commit a5ac4bb

Please sign in to comment.