Skip to content

Commit

Permalink
Add checkpointing
Browse files Browse the repository at this point in the history
  • Loading branch information
Ig-dolci committed Sep 4, 2024
1 parent 58036a0 commit cd35e51
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
16 changes: 12 additions & 4 deletions demos/with_automatic_differentiation/run_fwi_ad.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import firedrake as fire
import firedrake.adjoint as fire_ad
from checkpoint_schedules import Revolve
import spyro


Expand Down Expand Up @@ -53,12 +54,12 @@
}
model["aut_dif"] = {
"status": True,
"checkpointing": False,
"checkpointing": True,
}

model["timeaxis"] = {
"t0": 0.0, # Initial time for event
"tf": 0.6, # Final time for event (for test 7)
"tf": 0.8, # Final time for event (for test 7)
"dt": 0.001, # timestep size (divided by 2 in the test 4. dt for test 3 is 0.00050)
"amplitude": 1, # the Ricker has an amplitude of 1.
"nspool": 20, # (20 for dt=0.00050) how frequently to output solution to pvds
Expand Down Expand Up @@ -87,6 +88,12 @@ def forward(
):
if annotate:
fire_ad.continue_annotation()
if model["aut_dif"]["checkpointing"]:
total_steps = int(model["timeaxis"]["tf"] / model["timeaxis"]["dt"])
steps_store = int(total_steps / 10) # Store 10% of the steps.
tape = fire_ad.get_working_tape()
tape.progress_bar = fire.ProgressBar
tape.enable_checkpointing(Revolve(total_steps, steps_store))
if model["parallelism"]["type"] is None:
outfile = fire.VTKFile("solution.pvd")
receiver_data = []
Expand Down Expand Up @@ -147,8 +154,9 @@ def forward(
# (3 in this case).
J_hat = fire_ad.EnsembleReducedFunctional(
J, fire_ad.Control(c_guess), my_ensemble)
fire_ad.taylor_test(J_hat, c_guess, fire.Function(V).assign(1.0))
c_optimised = fire_ad.minimize(J_hat, method="L-BFGS-B",
options={"disp": True, "maxiter": 10},
bounds=(1.5, 3.5),
derivative_options={"riesz_representation": 'l2'})
derivative_options={"riesz_representation": 'l2'})

fire.VTKFile("c_optimised.pvd").write(c_optimised)
11 changes: 10 additions & 1 deletion spyro/solvers/forward_AD.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,16 @@ def execute_acoustic(
J_val = 0.0
receiver_data = []
total_steps = int(self.model["timeaxis"]["tf"] / self.model["timeaxis"]["dt"])
for step in range(total_steps):
if (
fire_ad.get_working_tape()._checkpoint_manager
and self.model["aut_dif"]["checkpointing"]
):
time_range = fire_ad.get_working_tape().timestepper(
iter(range(total_steps)))
else:
time_range = range(total_steps)

for step in time_range:
source_function.assign(wavelet[step] * q_s)
solver.solve()
u_nm1.assign(u_n)
Expand Down

0 comments on commit cd35e51

Please sign in to comment.