-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #124 from NDF-Poli-USP/dolci/auto_diff
Forward solver adapted for automatic differentiation.
- Loading branch information
Showing
13 changed files
with
599 additions
and
497 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import firedrake as fire | ||
import spyro | ||
from demos.with_automatic_differentiation.utils import \ | ||
model_settings, make_c_camembert | ||
import os | ||
os.environ["OMP_NUM_THREADS"] = "1" | ||
|
||
# --- Basid setup to run a forward simulation with AD --- # | ||
|
||
model = model_settings() | ||
|
||
# Use emsemble parallelism. | ||
M = model["parallelism"]["num_spacial_cores"] | ||
my_ensemble = fire.Ensemble(fire.COMM_WORLD, M) | ||
mesh = fire.UnitSquareMesh(50, 50, comm=my_ensemble.comm) | ||
element = fire.FiniteElement( | ||
model["opts"]["method"], mesh.ufl_cell(), degree=model["opts"]["degree"], | ||
variant=model["opts"]["quadrature"] | ||
) | ||
V = fire.FunctionSpace(mesh, element) | ||
|
||
|
||
forward_solver = spyro.solvers.forward_ad.ForwardSolver(model, mesh, V) | ||
|
||
c_true = make_c_camembert(mesh, V) | ||
# Ricker wavelet | ||
wavelet = spyro.full_ricker_wavelet( | ||
model["timeaxis"]["dt"], model["timeaxis"]["tf"], | ||
model["acquisition"]["frequency"], | ||
) | ||
|
||
if model["parallelism"]["type"] is None: | ||
outfile = fire.VTKFile("solution.pvd") | ||
for sn in range(len(model["acquisition"]["source_pos"])): | ||
rec_data, _ = forward_solver.execute(c_true, sn, wavelet) | ||
sol = forward_solver.solution | ||
outfile.write(sol) | ||
else: | ||
# source_number based on the ensemble.ensemble_comm.rank | ||
source_number = my_ensemble.ensemble_comm.rank | ||
rec_data, _ = forward_solver.execute_acoustic( | ||
c_true, source_number, wavelet) | ||
sol = forward_solver.solution | ||
fire.VTKFile( | ||
"solution_" + str(source_number) + ".pvd", comm=my_ensemble.comm | ||
).write(sol) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
import firedrake as fire | ||
import firedrake.adjoint as fire_ad | ||
from checkpoint_schedules import Revolve | ||
import spyro | ||
from demos.with_automatic_differentiation import utils | ||
import os | ||
os.environ["OMP_NUM_THREADS"] = "1" | ||
|
||
# --- Basid setup to run a FWI --- # | ||
model = utils.model_settings() | ||
|
||
|
||
def forward( | ||
c, compute_functional=False, true_data_receivers=None, annotate=False | ||
): | ||
"""Time-stepping acoustic forward solver. | ||
The time integration is done using a central difference scheme. | ||
Parameters | ||
---------- | ||
c : firedrake.Function | ||
Velocity field. | ||
compute_functional : bool, optional | ||
Whether to compute the functional. If True, the true receiver | ||
data must be provided. | ||
true_data_receivers : list, optional | ||
True receiver data. This is used to compute the functional. | ||
annotate : bool, optional | ||
If True, the forward model is annotated for automatic differentiation. | ||
Returns | ||
------- | ||
(receiver_data : list, J_val : float) | ||
Receiver data and functional value. | ||
""" | ||
if annotate: | ||
fire_ad.continue_annotation() | ||
if model["aut_dif"]["checkpointing"]: | ||
total_steps = int(model["timeaxis"]["tf"] / model["timeaxis"]["dt"]) | ||
steps_store = int(total_steps / 10) # Store 10% of the steps. | ||
tape = fire_ad.get_working_tape() | ||
tape.progress_bar = fire.ProgressBar | ||
tape.enable_checkpointing(Revolve(total_steps, steps_store)) | ||
|
||
if model["parallelism"]["type"] is None: | ||
outfile = fire.VTKFile("solution.pvd") | ||
receiver_data = [] | ||
J = 0.0 | ||
for sn in range(len(model["acquisition"]["source_pos"])): | ||
rec_data, J_val = forward_solver.execute_acoustic(c, sn, wavelet) | ||
receiver_data.append(rec_data) | ||
J += J_val | ||
sol = forward_solver.solution | ||
outfile.write(sol) | ||
|
||
else: | ||
# source_number based on the ensemble.ensemble_comm.rank | ||
source_number = my_ensemble.ensemble_comm.rank | ||
receiver_data, J = forward_solver.execute_acoustic( | ||
c, source_number, wavelet, | ||
compute_functional=compute_functional, | ||
true_data_receivers=true_data_receivers | ||
) | ||
sol = forward_solver.solution | ||
fire.VTKFile( | ||
"solution_" + str(source_number) + ".pvd", comm=my_ensemble.comm | ||
).write(sol) | ||
|
||
return receiver_data, J | ||
|
||
|
||
# Use emsemble parallelism. | ||
M = model["parallelism"]["num_spacial_cores"] | ||
my_ensemble = fire.Ensemble(fire.COMM_WORLD, M) | ||
mesh = fire.UnitSquareMesh(50, 50, comm=my_ensemble.comm) | ||
element = fire.FiniteElement( | ||
model["opts"]["method"], mesh.ufl_cell(), degree=model["opts"]["degree"], | ||
variant=model["opts"]["quadrature"] | ||
) | ||
V = fire.FunctionSpace(mesh, element) | ||
|
||
|
||
forward_solver = spyro.solvers.forward_ad.ForwardSolver(model, mesh, V) | ||
# Camembert model. | ||
c_true = utils.make_c_camembert(mesh, V) | ||
# Ricker wavelet | ||
wavelet = spyro.full_ricker_wavelet( | ||
model["timeaxis"]["dt"], model["timeaxis"]["tf"], | ||
model["acquisition"]["frequency"], | ||
) | ||
|
||
true_rec, _ = forward(c_true) | ||
|
||
# --- FWI with AD --- # | ||
c_guess = utils.make_c_camembert(mesh, V, c_guess=True) | ||
guess_rec, J = forward( | ||
c_guess, compute_functional=True, true_data_receivers=true_rec, | ||
annotate=True | ||
) | ||
|
||
# :class:`~.EnsembleReducedFunctional` is employed to recompute in | ||
# parallel the functional and its gradient associated with the multiple sources | ||
# (3 in this case). | ||
J_hat = fire_ad.EnsembleReducedFunctional( | ||
J, fire_ad.Control(c_guess), my_ensemble) | ||
c_optimised = fire_ad.minimize(J_hat, method="L-BFGS-B", | ||
options={"disp": True, "maxiter": 10}, | ||
bounds=(1.5, 3.5), | ||
derivative_options={"riesz_representation": 'l2'}) | ||
|
||
fire.VTKFile("c_optimised.pvd").write(c_optimised) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# --- Basid setup to run a forward simulation with AD --- # | ||
import firedrake as fire | ||
import spyro | ||
|
||
def model_settings(): | ||
"""Model settings for forward and Full Waveform Inversion (FWI) | ||
simulations. | ||
Returns | ||
------- | ||
model : dict | ||
Dictionary containing the model settings. | ||
""" | ||
|
||
model = {} | ||
|
||
model["opts"] = { | ||
"method": "KMV", # either CG or mass_lumped_triangle | ||
"quadrature": "KMV", # Equi or mass_lumped_triangle | ||
"degree": 1, # p order | ||
"dimension": 2, # dimension | ||
"regularization": False, # regularization is on? | ||
"gamma": 1e-5, # regularization parameter | ||
} | ||
|
||
model["parallelism"] = { | ||
# options: | ||
# `shots_parallelism`. Shots parallelism. | ||
# None - no shots parallelism. | ||
"type": "shots_parallelism", | ||
"num_spacial_cores": 1, # Number of cores to use in the spatial | ||
# parallelism. | ||
} | ||
|
||
# Define the domain size without the ABL. | ||
model["mesh"] = { | ||
"Lz": 1.0, # depth in km - always positive | ||
"Lx": 1.0, # width in km - always positive | ||
"Ly": 0.0, # thickness in km - always positive | ||
"meshfile": "not_used.msh", | ||
"initmodel": "not_used.hdf5", | ||
"truemodel": "not_used.hdf5", | ||
} | ||
|
||
# Specify a 250-m Absorbing Boundary Layer (ABL) on the three sides of the domain to damp outgoing waves. | ||
model["BCs"] = { | ||
"status": False, # True or False, used to turn on any type of BC | ||
"outer_bc": "non-reflective", # none or non-reflective (outer boundary condition) | ||
"abl_bc": "none", # none, gaussian-taper, or alid | ||
"lz": 0.0, # thickness of the ABL in the z-direction (km) - always positive | ||
"lx": 0.0, # thickness of the ABL in the x-direction (km) - always positive | ||
"ly": 0.0, # thickness of the ABL in the y-direction (km) - always positive | ||
} | ||
|
||
model["acquisition"] = { | ||
"source_type": "Ricker", | ||
"source_pos": spyro.create_transect((0.2, 0.15), (0.8, 0.15), 3), | ||
"frequency": 7.0, | ||
"delay": 1.0, | ||
"receiver_locations": spyro.create_transect((0.2, 0.2), (0.8, 0.2), 10), | ||
} | ||
model["aut_dif"] = { | ||
"status": True, | ||
"checkpointing": True, | ||
} | ||
|
||
model["timeaxis"] = { | ||
"t0": 0.0, # Initial time for event | ||
"tf": 0.8, # Final time for event (for test 7) | ||
"dt": 0.001, # timestep size (divided by 2 in the test 4. dt for test 3 is 0.00050) | ||
"amplitude": 1, # the Ricker has an amplitude of 1. | ||
"nspool": 20, # (20 for dt=0.00050) how frequently to output solution to pvds | ||
"fspool": 1, # how frequently to save solution to RAM | ||
} | ||
|
||
return model | ||
|
||
|
||
def make_c_camembert(mesh, function_space, c_guess=False, plot_c=False): | ||
"""Acoustic velocity model. | ||
Parameters | ||
---------- | ||
mesh : firedrake.Mesh | ||
Mesh. | ||
function_space : firedrake.FunctionSpace | ||
Function space. | ||
c_guess : bool, optional | ||
If True, the initial guess for the velocity field is returned. | ||
plot_c : bool, optional | ||
If True, the velocity field is saved to a VTK file. | ||
""" | ||
x, z = fire.SpatialCoordinate(mesh) | ||
if c_guess: | ||
c = fire.Function(function_space).interpolate(1.5 + 0.0 * x) | ||
else: | ||
c = fire.Function(function_space).interpolate( | ||
2.5 | ||
+ 1 * fire.tanh(100 * (0.125 - fire.sqrt((x - 0.5) ** 2 + (z - 0.5) ** 2))) | ||
) | ||
if plot_c: | ||
outfile = fire.VTKFile("acoustic_cp.pvd") | ||
outfile.write(c) | ||
return c |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.