Skip to content

Commit

Permalink
Disk checkpointing. (#173)
Browse files Browse the repository at this point in the history
* Disk checkpointing.

---------

Co-authored-by: David A. Ham <david.ham@imperial.ac.uk>
  • Loading branch information
Ig-dolci and dham authored Nov 21, 2024
1 parent c7939a4 commit a8ee848
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 32 deletions.
1 change: 1 addition & 0 deletions pyadjoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
annotate_tape, stop_annotating, pause_annotation, continue_annotation)
from .adjfloat import AdjFloat, exp, log
from .reduced_functional import ReducedFunctional
from .checkpointing import disk_checkpointing_callback
from .drivers import compute_gradient, compute_hessian, solve_adjoint
from .verification import taylor_test, taylor_to_dict
from .overloaded_type import OverloadedType, create_overloaded_object
Expand Down
39 changes: 34 additions & 5 deletions pyadjoint/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import sys
from functools import singledispatchmethod
from checkpoint_schedules import Copy, Move, EndForward, EndReverse, Forward, Reverse, StorageType
# A callback interface allowing the user to provide a
# custom error message when disk checkpointing is not configured.
disk_checkpointing_callback = {}


class CheckpointError(RuntimeError):
Expand Down Expand Up @@ -54,8 +57,8 @@ def __init__(self, schedule, tape):
and not tape._package_data
):
raise CheckpointError(
"The schedule employs disk checkpointing but it is not configured."
)
"The schedule employs disk checkpointing but it is not configured.\n"
+ "\n".join(disk_checkpointing_callback.values()))
self.tape = tape
self._schedule = schedule
self.forward_schedule = []
Expand Down Expand Up @@ -152,6 +155,13 @@ def _(self, cp_action, timestep):
# Store the checkpoint data. This is the required data for
# computing the adjoint model from the step `n1`.
_store_adj_dependencies = True
if (
(_store_checkpointable_state or _store_adj_dependencies)
and cp_action.storage == StorageType.DISK
):
for package in self.tape._package_data.values():
package.continue_checkpointing()

self.tape.timesteps[timestep - 1].checkpoint(
_store_checkpointable_state, _store_adj_dependencies)
# Remove unnecessary variables in working memory from previous steps.
Expand All @@ -164,6 +174,11 @@ def _(self, cp_action, timestep):
self.tape.get_blocks().append_step()
if cp_action.write_ics:
self.tape.latest_checkpoint = cp_action.n0

if cp_action.storage == StorageType.DISK:
# Activate disk checkpointing only in the checkpointing process.
for package in self.tape._package_data.values():
package.pause_checkpointing()
return True
else:
return False
Expand All @@ -186,11 +201,15 @@ def recompute(self, functional=None):
if self.mode == Mode.RECORD:
# Finalise the taping process.
self.end_taping()
if self._schedule.uses_storage_type(StorageType.DISK):
# Clear the data of the current state before recomputing.
for package in self.tape._package_data.values():
package.reset()
self.mode = Mode.RECOMPUTE
with self.tape.progress_bar("Evaluating Functional", max=self.total_timesteps) as progress_bar:
# Restore the initial condition to advance the forward model from the step 0.
current_step = self.tape.timesteps[self.forward_schedule[0].n0]
current_step.restore_from_checkpoint()
current_step.restore_from_checkpoint(self.forward_schedule[0].storage)
for cp_action in self.forward_schedule:
self._current_action = cp_action
self.process_operation(cp_action, progress_bar, functional=functional)
Expand Down Expand Up @@ -271,6 +290,12 @@ def _(self, cp_action, progress_bar, functional=None, **kwargs):
_store_checkpointable_state = True
if cp_action.write_adj_deps:
_store_adj_dependencies = True
if (
(_store_checkpointable_state or _store_adj_dependencies)
and cp_action.storage == StorageType.DISK
):
for package in self.tape._package_data.values():
package.continue_checkpointing()
current_step.checkpoint(
_store_checkpointable_state, _store_adj_dependencies)

Expand All @@ -294,6 +319,10 @@ def _(self, cp_action, progress_bar, functional=None, **kwargs):
for var in (current_step.checkpointable_state - to_keep):
var._checkpoint = None
step += 1
if cp_action.storage == StorageType.DISK:
# Activate disk checkpointing only in the checkpointing process.
for package in self.tape._package_data.values():
package.pause_checkpointing()

@process_operation.register(Reverse)
def _(self, cp_action, progress_bar, markings, functional=None, **kwargs):
Expand Down Expand Up @@ -324,12 +353,12 @@ def _(self, cp_action, progress_bar, markings, functional=None, **kwargs):
@process_operation.register(Copy)
def _(self, cp_action, progress_bar, **kwargs):
current_step = self.tape.timesteps[cp_action.n]
current_step.restore_from_checkpoint()
current_step.restore_from_checkpoint(cp_action.from_storage)

@process_operation.register(Move)
def _(self, cp_action, progress_bar, **kwargs):
current_step = self.tape.timesteps[cp_action.n]
current_step.restore_from_checkpoint()
current_step.restore_from_checkpoint(cp_action.from_storage)
current_step.delete_checkpoint()

@process_operation.register(EndForward)
Expand Down
39 changes: 29 additions & 10 deletions pyadjoint/tape.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from itertools import chain
from typing import Optional, Iterable
from abc import ABC, abstractmethod
from .checkpointing import CheckpointManager, CheckpointError
from .checkpointing import CheckpointManager, CheckpointError, StorageType

_working_tape = None
_annotation_enabled = False
Expand Down Expand Up @@ -293,7 +293,6 @@ def enable_checkpointing(self, schedule):
Args:
schedule (checkpoint_schedules.schedule): A schedule provided by the
checkpoint_schedules package.
max_n (int, optional): The number of total steps.
"""
if self._blocks:
raise CheckpointError(
Expand Down Expand Up @@ -775,23 +774,31 @@ def checkpoint(self, checkpointable_state, adj_dependencies):
Args:
checkpointable_state (bool): If True, store the checkpointable state
required to restart from the start of a timestep.
adj_dependencies): (bool): If True, store the adjoint dependencies required
adj_dependencies (bool): If True, store the adjoint dependencies required
to compute the adjoint of a timestep.
"""
with stop_annotating():
if checkpointable_state:
for var in self.checkpointable_state:
self._checkpoint[var] = var.checkpoint
self._checkpoint[var] = var.saved_output._ad_create_checkpoint()

if adj_dependencies:
for var in self.adjoint_dependencies:
self._checkpoint[var] = var.checkpoint
self._checkpoint[var] = var.saved_output._ad_create_checkpoint()

def restore_from_checkpoint(self):
def restore_from_checkpoint(self, from_storage):
"""Restore the block var checkpoints from the timestep checkpoint."""

for var in self._checkpoint:
var.checkpoint = self._checkpoint[var]
from .overloaded_type import OverloadedType
for var, checkpoint in self._checkpoint.items():
if (
from_storage == StorageType.DISK
and isinstance(checkpoint, OverloadedType)
):
# checkpoint._ad_restore_checkpoint should be able to restore
# from disk.
var.checkpoint = checkpoint._ad_restore_at_checkpoint(checkpoint)
else:
var.checkpoint = checkpoint

def delete_checkpoint(self):
"""Delete the stored checkpoint references."""
Expand Down Expand Up @@ -881,10 +888,22 @@ def checkpoint(self):

@abstractmethod
def restore_from_checkpoint(self, state):
"""Restore state from a previously stored checkpioint."""
"""Restore state from a previously stored checkpoint."""
pass

@abstractmethod
def copy(self):
"""Produce a new copy of state to be passed to a copy of the tape."""
pass

@abstractmethod
def continue_checkpointing(self):
"""Continue the checkpointing process on disk.
"""
pass

@abstractmethod
def pause_checkpointing(self):
"""Pause the checkpointing process on disk.
"""
pass
37 changes: 25 additions & 12 deletions tests/firedrake_adjoint/test_burgers_newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,20 @@
import numpy as np
set_log_level(CRITICAL)
continue_annotation()
n = 30
mesh = UnitIntervalMesh(n)
V = FunctionSpace(mesh, "CG", 2)
end = 0.3
timestep = Constant(1.0/n)
steps = int(end/float(timestep)) + 1

def basics():
n = 30
mesh = UnitIntervalMesh(n)
end = 0.3
timestep = Constant(1.0/n)
steps = int(end/float(timestep)) + 1
return mesh, timestep, steps

def Dt(u, u_, timestep):
return (u - u_)/timestep


def J(ic, solve_type, checkpointing):
def J(ic, solve_type, timestep, steps, V):
u_ = Function(V)
u = Function(V)
v = TestFunction(V)
Expand Down Expand Up @@ -65,6 +66,7 @@ def J(ic, solve_type, checkpointing):
def test_burgers_newton(solve_type, checkpointing):
"""Adjoint-based gradient tests with and without checkpointing.
"""
mesh, timestep, steps = basics()
tape = get_working_tape()
tape.progress_bar = ProgressBar
if checkpointing:
Expand All @@ -73,13 +75,17 @@ def test_burgers_newton(solve_type, checkpointing):
if checkpointing == "SingleMemory":
schedule = SingleMemoryStorageSchedule()
if checkpointing == "Mixed":
schedule = MixedCheckpointSchedule(steps, steps//3, storage=StorageType.RAM)
enable_disk_checkpointing()
schedule = MixedCheckpointSchedule(steps, steps//3, storage=StorageType.DISK)
if checkpointing == "NoneAdjoint":
schedule = NoneCheckpointSchedule()
tape.enable_checkpointing(schedule)
if schedule.uses_storage_type(StorageType.DISK):
mesh = checkpointable_mesh(mesh)
x, = SpatialCoordinate(mesh)
V = FunctionSpace(mesh, "CG", 2)
ic = project(sin(2. * pi * x), V)
val = J(ic, solve_type, checkpointing)
val = J(ic, solve_type, timestep, steps, V)
if checkpointing:
assert len(tape.timesteps) == steps
Jhat = ReducedFunctional(val, Control(ic))
Expand Down Expand Up @@ -109,13 +115,15 @@ def test_burgers_newton(solve_type, checkpointing):
def test_checkpointing_validity(solve_type, checkpointing):
"""Compare forward and backward results with and without checkpointing.
"""
mesh, timestep, steps = basics()
V = FunctionSpace(mesh, "CG", 2)
# Without checkpointing
tape = get_working_tape()
tape.progress_bar = ProgressBar
x, = SpatialCoordinate(mesh)
ic = project(sin(2.*pi*x), V)

val0 = J(ic, solve_type, False)
val0 = J(ic, solve_type, timestep, steps, V)
Jhat = ReducedFunctional(val0, Control(ic))
dJ0 = Jhat.derivative()
tape.clear_tape()
Expand All @@ -125,8 +133,13 @@ def test_checkpointing_validity(solve_type, checkpointing):
if checkpointing == "Revolve":
tape.enable_checkpointing(Revolve(steps, steps//3))
if checkpointing == "Mixed":
tape.enable_checkpointing(MixedCheckpointSchedule(steps, steps//3, storage=StorageType.RAM))
val1 = J(ic, solve_type, True)
enable_disk_checkpointing()
tape.enable_checkpointing(MixedCheckpointSchedule(steps, steps//3, storage=StorageType.DISK))
mesh = checkpointable_mesh(mesh)
V = FunctionSpace(mesh, "CG", 2)
x, = SpatialCoordinate(mesh)
ic = project(sin(2.*pi*x), V)
val1 = J(ic, solve_type, timestep, steps, V)
Jhat = ReducedFunctional(val1, Control(ic))
assert len(tape.timesteps) == steps
assert np.allclose(val0, val1)
Expand Down
22 changes: 17 additions & 5 deletions tests/firedrake_adjoint/test_disk_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from firedrake import *
from firedrake.__future__ import *
from firedrake.adjoint import *
from firedrake.adjoint_utils.checkpointing import disk_checkpointing
import numpy as np
import os
from checkpoint_schedules import SingleDiskStorageSchedule


def adjoint_example(fine, coarse):
Expand Down Expand Up @@ -54,19 +56,24 @@ def adjoint_example(fine, coarse):
return Jnew, grad_Jnew


def test_disk_checkpointing():
@pytest.mark.parametrize("checkpoint_schedule", [True, False])
def test_disk_checkpointing(checkpoint_schedule):
# Use a Firedrake Tape subclass that supports disk checkpointing.
set_working_tape(Tape())
tape = get_working_tape()
tape.clear_tape()
enable_disk_checkpointing()

if checkpoint_schedule:
tape.enable_checkpointing(SingleDiskStorageSchedule())
fine = checkpointable_mesh(UnitSquareMesh(10, 10, name="fine"))
coarse = checkpointable_mesh(UnitSquareMesh(4, 4, name="coarse"))
J_disk, grad_J_disk = adjoint_example(fine, coarse)

if checkpoint_schedule:
assert disk_checkpointing() is False
tape.clear_tape()
pause_disk_checkpointing()
if not checkpoint_schedule:
pause_disk_checkpointing()

J_mem, grad_J_mem = adjoint_example(fine, coarse)

Expand All @@ -75,5 +82,10 @@ def test_disk_checkpointing():
tape.clear_tape()


if __name__ == "__main__":
test_disk_checkpointing()
def test_disk_checkpointing_error():
tape = get_working_tape()
# check the raise of the exception
with pytest.raises(RuntimeError):
tape.enable_checkpointing(SingleDiskStorageSchedule())
assert disk_checkpointing_callback["firedrake"] == "Please call enable_disk_checkpointing() "\
"before checkpointing on the disk."

0 comments on commit a8ee848

Please sign in to comment.