Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Checkpointer revival #326

Merged
merged 28 commits into from
Aug 28, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
1212c82
Checkpointer that saves the bare essentials using JLD2.
ali-ramadhan Aug 2, 2019
d434e55
Revive checkpointer test
ali-ramadhan Aug 2, 2019
40fd4b0
Merge branch 'master' into checkpointer
ali-ramadhan Aug 22, 2019
b47fda8
Do not save forcing functions.
ali-ramadhan Aug 22, 2019
f7476b4
Move output utils to the bottom.
ali-ramadhan Aug 22, 2019
171f582
Validate frequency and interval.
ali-ramadhan Aug 22, 2019
14c221b
Remove redundant warnings.
ali-ramadhan Aug 22, 2019
823d310
User can select kwargs to pass to restored model (e.g. for forcing).
ali-ramadhan Aug 22, 2019
7793b1e
Clean up checkpointer utils.
ali-ramadhan Aug 22, 2019
02cb710
Update checkpointer test.
ali-ramadhan Aug 22, 2019
ddbc1d6
Forgot to test checkpointer on GPU.
ali-ramadhan Aug 22, 2019
8aeacbd
Switch to saveproperty! for JLD2OutputWriter.
ali-ramadhan Aug 25, 2019
0f933c8
Analogous serializeproperty! methods for checkpointing.
ali-ramadhan Aug 25, 2019
ebb0f34
Saving/serializing boundary conditions.
ali-ramadhan Aug 25, 2019
6935cf6
Refactor Checkpointer.
ali-ramadhan Aug 25, 2019
c96175b
Properly restore fields and boundary conditions.
ali-ramadhan Aug 25, 2019
cbd4a8d
Index named tuples by name, rather than number/range.
ali-ramadhan Aug 26, 2019
64db4c7
hasfunction utility.
ali-ramadhan Aug 26, 2019
c2c01ee
Revert "Index named tuples by name, rather than number/range."
ali-ramadhan Aug 26, 2019
3552f77
Checkpointer will serialize bcs as a whole.
ali-ramadhan Aug 26, 2019
89518cf
Cleanup, comments, and update test.
ali-ramadhan Aug 26, 2019
cffc11d
Fix typo.
ali-ramadhan Aug 26, 2019
bcfaff8
Nuke HorizontalAverages and VerticalPlanes scratch structs.
ali-ramadhan Aug 26, 2019
ea91e5f
Simpler validate interval and frequency function
ali-ramadhan Aug 27, 2019
517b32b
Clean up checkpointer constructor
ali-ramadhan Aug 27, 2019
898216b
adds array_refs property to checkpointer and generalizes hasfunction
glwagner Aug 27, 2019
00b6eba
gets has_reference function working with a little elbow grease
glwagner Aug 27, 2019
4e72d35
moves using Distributed to Oceananigans.jl from output_writers
glwagner Aug 27, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Oceananigans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,10 @@ export
BinaryOutputWriter,
NetCDFOutputWriter,
JLD2OutputWriter,
Checkpointer,
write_output,
read_output,
restore_from_checkpoint,

# Model diagnostics
Diagnostic,
Expand Down
69 changes: 69 additions & 0 deletions src/output_writers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,72 @@ function time_to_write(clock, out::JLD2OutputWriter)
return false
end
end

mutable struct Checkpointer <: OutputWriter
dir :: String
prefix :: String
output_frequency :: Int
end

function Checkpointer(; output_frequency, dir=".", prefix="checkpoint", force=false)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The kwarg 'force' is not used here. It should either be used, or removed from the function signature.

In the JLD2OutputWriter, the kwarg force indicates whether file creation should be 'forced' (it corresponds to the same keyword passed to mkpath.

mkpath(dir)
return Checkpointer(dir, prefix, output_frequency)
end

function savesubfields!(file, model, name, flds=propertynames(getproperty(model, name)))
for f in flds
file["$name/$f"] = Array(getproperty(getproperty(model, name), f).data.parent)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put an if statement that does not save a field if its type is Function?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds like a good check. Won't be needed as long we maintain checkpointed_fieldsets but good check either way.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that we should allow users to supply the modelfields that are to be checkpointed. In that case, such a check will be important.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Continuing from my comment below, the if-statement can also emit a warning that field x will not be saved.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ali-ramadhan was this resolved?

end
return nothing
end

checkpointed_structs = [:arch, :boundary_conditions, :grid, :clock, :eos, :constants, :closure]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These fields could be made properties of the Checkpointer. This will allow the user to add/remove fields from checkpointing, which may be generally useful in the future and is currently useful for omitting boundary conditions from checkpointing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ali-ramadhan was this resolved?

checkpointed_fieldsets = [:velocities, :tracers, :G, :Gp]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

G and Gp are not fields of model.

Copy link
Member Author

@ali-ramadhan ali-ramadhan Aug 5, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True. This was before PR #325 was merged and included the tendencies with AdamsBashforthTimestepper. Will update.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ali-ramadhan was this resolved?


function write_output(model, c::Checkpointer)
@warn "Checkpointer will not save forcing functions, output writers, or diagnostics. They will need to be " *
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense to have a warning that is always printed? Perhaps is it simply better to document this aspect of the checkpointer?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm, might then be good to just print the warning if there is a non-zero forcing function, or an output writer or diagnostic included.

Then yeah warning may be removed if the checkpointer is well-documented.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is what I propose:

  1. Allow users to set the modelfields that are to be checkpoints as an argument to the Checkpointer constructor

  2. If one of those fields that the user has asked to be saved contains a function, emit a warning.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ali-ramadhan was this resolved?

"restored manually."

filepath = joinpath(c.dir, c.prefix * string(model.clock.iteration) * ".jld2")

jldopen(filepath, "w") do file
# Checkpointing model properties that we can just serialize.
[file["$p"] = getproperty(model, p) for p in checkpointed_structs]

# Checkpointing structs containing fields.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps this step can be combined with the first, along with an if statement that inspects the content of the struct to be saved and determines whether it contains Fields or not.

[savesubfields!(file, model, p) for p in checkpointed_fieldsets]
end
end

_arr(::CPU, a) = a
_arr(::GPU, a) = CuArray(a)

function restore_from_checkpoint(filepath)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function must have a ; kwargs... argument which is then merged with the retrieved kwargs from the checkpoint file before being passed to the model constructor. This is needed to restore models with forcing functions or non-default boundary conditions.

@warn "Checkpointer cannot restore forcing functions, output writers, or diagnostics. They will need to be " *
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above. What we need is documentation about proper checkpointing, which will include information about how to restore a model from a checkpoint that includes functions. In fact, we could even provide features in the checkpointer that streamline this process (by indicating parts of the model structure that are associated with functions, and asking the user to provide those functions during checkpoint restoration).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah warning doesn't have be in restore_from_checkpoint. Passing forcing functions is a good idea.

"restored manually."

kwargs = Dict{Symbol, Any}() # We'll store all the kwargs we need to initialize a Model.

file = jldopen(filepath, "r")

# Restore model properties that were just serialized.
for p in checkpointed_structs
kwargs[Symbol(p)] = file["$p"]
end

# The Model constructor needs N and L.
kwargs[:N] = (kwargs[:grid].Nx, kwargs[:grid].Ny, kwargs[:grid].Nz)
kwargs[:L] = (kwargs[:grid].Lx, kwargs[:grid].Ly, kwargs[:grid].Lz)

model = Model(; kwargs...)

for p in checkpointed_fieldsets
for subp in propertynames(getproperty(model, p))
getproperty(getproperty(model, p), subp).data.parent .= _arr(model.arch, file["$p/$subp"])
end
end

close(file)

return model
end
54 changes: 54 additions & 0 deletions test/test_output_writers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,65 @@ function run_thermal_bubble_netcdf_tests()
@test all(S .≈ data(model.tracers.S))
end

"""
Run two coarse rising thermal bubble simulations and make sure that when
restarting from a checkpoint, the restarted simulation matches the non-restarted
simulation numerically.
"""
function run_thermal_bubble_checkpointer_tests(arch)
Nx, Ny, Nz = 16, 16, 16
Lx, Ly, Lz = 100, 100, 100
Δt = 6

true_model = Model(N=(Nx, Ny, Nz), L=(Lx, Ly, Lz), ν=4e-2, κ=4e-2, arch=arch)

# Add a cube-shaped warm temperature anomaly that takes up the middle 50%
# of the domain volume.
i1, i2 = round(Int, Nx/4), round(Int, 3Nx/4)
j1, j2 = round(Int, Ny/4), round(Int, 3Ny/4)
k1, k2 = round(Int, Nz/4), round(Int, 3Nz/4)
true_model.tracers.T.data[i1:i2, j1:j2, k1:k2] .+= 0.01

checkpointed_model = deepcopy(true_model)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it necessary to copy the model before checkpointing?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's creating a second model that will be checkpointed as opposed to true_model which isn't.

Probably over paranoid but I wanted the two models to be time-stepped separately.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you worried that checkpointing will modify model?

Since this behavior is undesirable / unexpected, perhaps it should simply be tested for, so that we can assume that the model is not modified.


time_step!(true_model, 9, Δt)

checkpointer = Checkpointer(output_frequency=5)
push!(checkpointed_model.output_writers, checkpointer)

# Checkpoint should be saved as "test_model_checkpoint_5.jld" after the 5th iteration.
time_step!(checkpointed_model, 5, Δt)

# Remove all knowledge of the checkpointed model.
checkpointed_model = nothing

restored_model = restore_from_checkpoint("checkpoint5.jld2")

time_step!(restored_model, 4, Δt; adams_bashforth_parameter = n->0.125)

# Now the true_model and restored_model should be identical.
@test all(restored_model.velocities.u.data .≈ true_model.velocities.u.data)
@test all(restored_model.velocities.v.data .≈ true_model.velocities.v.data)
@test all(restored_model.velocities.w.data .≈ true_model.velocities.w.data)
@test all(restored_model.tracers.T.data .≈ true_model.tracers.T.data)
@test all(restored_model.tracers.S.data .≈ true_model.tracers.S.data)
@test all(restored_model.G.Gu.data .≈ true_model.G.Gu.data)
@test all(restored_model.G.Gv.data .≈ true_model.G.Gv.data)
@test all(restored_model.G.Gw.data .≈ true_model.G.Gw.data)
@test all(restored_model.G.GT.data .≈ true_model.G.GT.data)
@test all(restored_model.G.GS.data .≈ true_model.G.GS.data)
end

@testset "Output writers" begin
println("Testing output writers...")

@testset "NetCDF" begin
println(" Testing NetCDF output writer...")
run_thermal_bubble_netcdf_tests()
end

@testset "Checkpointer" begin
println(" Testing Checkpointer...")
run_thermal_bubble_checkpointer_tests(CPU())
end
end