Checkpointer revival #326
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -324,3 +324,72 @@ function time_to_write(clock, out::JLD2OutputWriter) | |
return false | ||
end | ||
end | ||
|
||
mutable struct Checkpointer <: OutputWriter | ||
dir :: String | ||
prefix :: String | ||
output_frequency :: Int | ||
end | ||
|
||
function Checkpointer(; output_frequency, dir=".", prefix="checkpoint", force=false) | ||
mkpath(dir) | ||
return Checkpointer(dir, prefix, output_frequency) | ||
end | ||
|
||
function savesubfields!(file, model, name, flds=propertynames(getproperty(model, name))) | ||
for f in flds | ||
file["$name/$f"] = Array(getproperty(getproperty(model, name), f).data.parent) | ||
Put an
Sounds like a good check. Won't be needed as long as we maintain
I think that we should allow users to supply the
Continuing from my comment below, the
@ali-ramadhan was this resolved? |
||
end | ||
return nothing | ||
end | ||
|
||
checkpointed_structs = [:arch, :boundary_conditions, :grid, :clock, :eos, :constants, :closure] | ||
These fields could be made properties of the
@ali-ramadhan was this resolved? |
||
checkpointed_fieldsets = [:velocities, :tracers, :G, :Gp] | ||
True. This was before PR #325 was merged and included the tendencies with
@ali-ramadhan was this resolved? |
||
|
||
function write_output(model, c::Checkpointer) | ||
@warn "Checkpointer will not save forcing functions, output writers, or diagnostics. They will need to be " * | ||
Does it make sense to have a warning that is always printed? Perhaps it is simply better to document this aspect of the checkpointer?
Hmmm, it might then be good to print the warning only if there is a non-zero forcing function, or an output writer or diagnostic included. Then yeah, the warning may be removed if the checkpointer is well-documented.
Here is what I propose:
@ali-ramadhan was this resolved? |
||
"restored manually." | ||
|
||
filepath = joinpath(c.dir, c.prefix * string(model.clock.iteration) * ".jld2") | ||
|
||
jldopen(filepath, "w") do file | ||
# Checkpointing model properties that we can just serialize. | ||
[file["$p"] = getproperty(model, p) for p in checkpointed_structs] | ||
|
||
# Checkpointing structs containing fields. | ||
Perhaps this step can be combined with the first, along with an if statement that inspects the content of the struct to be saved and determines whether it contains |
||
[savesubfields!(file, model, p) for p in checkpointed_fieldsets] | ||
end | ||
end | ||
|
||
_arr(::CPU, a) = a | ||
_arr(::GPU, a) = CuArray(a) | ||
|
||
function restore_from_checkpoint(filepath) | ||
This function must have a |
||
@warn "Checkpointer cannot restore forcing functions, output writers, or diagnostics. They will need to be " * | ||
See above. What we need is documentation about proper checkpointing, which will include information about how to restore a model from a checkpoint that includes functions. In fact, we could even provide features in the checkpointer that streamline this process (by indicating parts of the model structure that are associated with functions, and asking the user to provide those functions during checkpoint restoration).
Yeah, the warning doesn't have to be in |
||
"restored manually." | ||
|
||
kwargs = Dict{Symbol, Any}() # We'll store all the kwargs we need to initialize a Model. | ||
|
||
file = jldopen(filepath, "r") | ||
|
||
# Restore model properties that were just serialized. | ||
for p in checkpointed_structs | ||
kwargs[Symbol(p)] = file["$p"] | ||
end | ||
|
||
# The Model constructor needs N and L. | ||
kwargs[:N] = (kwargs[:grid].Nx, kwargs[:grid].Ny, kwargs[:grid].Nz) | ||
kwargs[:L] = (kwargs[:grid].Lx, kwargs[:grid].Ly, kwargs[:grid].Lz) | ||
|
||
model = Model(; kwargs...) | ||
|
||
for p in checkpointed_fieldsets | ||
for subp in propertynames(getproperty(model, p)) | ||
getproperty(getproperty(model, p), subp).data.parent .= _arr(model.arch, file["$p/$subp"]) | ||
end | ||
end | ||
|
||
close(file) | ||
|
||
return model | ||
end |
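The review thread above suggests printing the warning only when the model actually carries something the checkpointer cannot save. A minimal sketch of that idea follows. It assumes a hypothetical `ToyModel` whose `forcing` is `nothing` when unset and whose `output_writers` and `diagnostics` are plain vectors; none of these names or conventions are taken from the PR's `Model` type.

```julia
# Hypothetical stand-in for the real model; the field layout is an assumption.
struct ToyModel
    forcing                       # `nothing` means "no forcing function" (assumed convention)
    output_writers :: Vector{Any}
    diagnostics    :: Vector{Any}
end

# Collect the names of the components that a checkpoint would silently drop.
function unsaved_components(model)
    names = String[]
    model.forcing === nothing     || push!(names, "forcing functions")
    isempty(model.output_writers) || push!(names, "output writers")
    isempty(model.diagnostics)    || push!(names, "diagnostics")
    return names
end

# Warn only when something will actually be lost.
# Returns `true` if a warning was issued, `false` otherwise.
function warn_if_lossy(model)
    names = unsaved_components(model)
    isempty(names) && return false
    list = join(names, ", ")
    @warn "Checkpointer will not save $list. They will need to be restored manually."
    return true
end
```

Returning a `Bool` keeps the check easy to test; whether the real checkpointer should warn, document the limitation, or both is the open question in the thread.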
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,11 +35,65 @@ function run_thermal_bubble_netcdf_tests() | |
@test all(S .≈ data(model.tracers.S)) | ||
end | ||
|
||
""" | ||
Run two coarse rising thermal bubble simulations and make sure that when | ||
restarting from a checkpoint, the restarted simulation matches the non-restarted | ||
simulation numerically. | ||
""" | ||
function run_thermal_bubble_checkpointer_tests(arch) | ||
Nx, Ny, Nz = 16, 16, 16 | ||
Lx, Ly, Lz = 100, 100, 100 | ||
Δt = 6 | ||
|
||
true_model = Model(N=(Nx, Ny, Nz), L=(Lx, Ly, Lz), ν=4e-2, κ=4e-2, arch=arch) | ||
|
||
# Add a cube-shaped warm temperature anomaly that spans the middle 50% | |
# of the domain in each direction. | |
i1, i2 = round(Int, Nx/4), round(Int, 3Nx/4) | ||
j1, j2 = round(Int, Ny/4), round(Int, 3Ny/4) | ||
k1, k2 = round(Int, Nz/4), round(Int, 3Nz/4) | ||
true_model.tracers.T.data[i1:i2, j1:j2, k1:k2] .+= 0.01 | ||
|
||
checkpointed_model = deepcopy(true_model) | ||
Why is it necessary to copy the model before checkpointing?
It's creating a second model that will be checkpointed as opposed to Probably over paranoid but I wanted the two models to be time-stepped separately.
Are you worried that checkpointing will modify Since this behavior is undesirable / unexpected, perhaps it should simply be tested for, so that we can assume that the model is not modified. |
||
|
||
time_step!(true_model, 9, Δt) | ||
|
||
checkpointer = Checkpointer(output_frequency=5) | ||
push!(checkpointed_model.output_writers, checkpointer) | ||
|
||
# Checkpoint should be saved as "checkpoint5.jld2" after the 5th iteration. | |
time_step!(checkpointed_model, 5, Δt) | ||
|
||
# Remove all knowledge of the checkpointed model. | ||
checkpointed_model = nothing | ||
|
||
restored_model = restore_from_checkpoint("checkpoint5.jld2") | ||
|
||
time_step!(restored_model, 4, Δt; adams_bashforth_parameter = n->0.125) | ||
|
||
# Now the true_model and restored_model should be identical. | ||
@test all(restored_model.velocities.u.data .≈ true_model.velocities.u.data) | ||
@test all(restored_model.velocities.v.data .≈ true_model.velocities.v.data) | ||
@test all(restored_model.velocities.w.data .≈ true_model.velocities.w.data) | ||
@test all(restored_model.tracers.T.data .≈ true_model.tracers.T.data) | ||
@test all(restored_model.tracers.S.data .≈ true_model.tracers.S.data) | ||
@test all(restored_model.G.Gu.data .≈ true_model.G.Gu.data) | ||
@test all(restored_model.G.Gv.data .≈ true_model.G.Gv.data) | ||
@test all(restored_model.G.Gw.data .≈ true_model.G.Gw.data) | ||
@test all(restored_model.G.GT.data .≈ true_model.G.GT.data) | ||
@test all(restored_model.G.GS.data .≈ true_model.G.GS.data) | ||
end | ||
|
||
@testset "Output writers" begin | ||
println("Testing output writers...") | ||
|
||
@testset "NetCDF" begin | ||
println(" Testing NetCDF output writer...") | ||
run_thermal_bubble_netcdf_tests() | ||
end | ||
|
||
@testset "Checkpointer" begin | ||
println(" Testing Checkpointer...") | ||
run_thermal_bubble_checkpointer_tests(CPU()) | ||
end | ||
end |
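The thread on `deepcopy` proposes testing directly that checkpointing does not modify the model. One way such a test could look is sketched below. `ToyState` and `write_toy_checkpoint` are hypothetical stand-ins, and the standard-library `Serialization` module is swapped in for JLD2 so the sketch stays self-contained.

```julia
using Serialization, Test

# Hypothetical stand-in for a model; this is not the PR's `Model` type.
mutable struct ToyState
    u         :: Vector{Float64}
    iteration :: Int
end

# Hypothetical writer: serializes a snapshot of the state to `path`.
function write_toy_checkpoint(path, state::ToyState)
    open(path, "w") do io
        serialize(io, (u = copy(state.u), iteration = state.iteration))
    end
    return nothing
end

# Compare the state against a deep copy taken before the checkpoint was written.
function checkpointing_leaves_state_untouched()
    state  = ToyState([1.0, 2.0, 3.0], 5)
    before = deepcopy(state)
    path   = joinpath(mktempdir(), "checkpoint5.bin")
    write_toy_checkpoint(path, state)
    return state.u == before.u && state.iteration == before.iteration
end

@test checkpointing_leaves_state_untouched()
```

With a check like this in place, the main test could time-step and checkpoint a single model without the defensive copy.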
The kwarg 'force' is not used here. It should either be used, or removed from the function signature.
In the `JLD2OutputWriter`, the kwarg `force` indicates whether file creation should be 'forced' (it corresponds to the same keyword passed to `mkpath`).
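If `force` is kept rather than removed, the constructor has to act on it. The sketch below shows one possible meaning, assuming `force` controls whether existing checkpoint files in `dir` may be deleted. That interpretation, and the `SketchCheckpointer` name, are assumptions for illustration and not the behavior of `JLD2OutputWriter`.

```julia
# Hypothetical variant of the PR's struct, renamed to avoid any clash.
struct SketchCheckpointer
    dir              :: String
    prefix           :: String
    output_frequency :: Int
end

function SketchCheckpointer(; output_frequency, dir=".", prefix="checkpoint", force=false)
    mkpath(dir)

    # Existing files that look like checkpoints written with the same prefix.
    stale = filter(f -> startswith(f, prefix) && endswith(f, ".jld2"), readdir(dir))

    if !isempty(stale)
        # Assumed semantics: refuse to touch old checkpoints unless force=true.
        force || error("Found existing checkpoints in $dir; pass force=true to remove them.")
        foreach(f -> rm(joinpath(dir, f)), stale)
    end

    return SketchCheckpointer(dir, prefix, output_frequency)
end
```

The alternative raised in the comment, dropping `force` from the signature, is simpler and avoids committing to any overwrite policy.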