Skip to content

Commit

Permalink
Merge pull request #124 from NDF-Poli-USP/dolci/auto_diff
Browse files Browse the repository at this point in the history
Forward solver adapted for automatic differentiation.
  • Loading branch information
Olender authored Sep 15, 2024
2 parents 2f19564 + 3ed4881 commit 42b3ad3
Show file tree
Hide file tree
Showing 13 changed files with 599 additions and 497 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,28 @@ jobs:
- uses: actions/checkout@v3
- name: Running serial tests
run: |
source /home/olender/firedrakes/2024_07_19/firedrake/bin/activate
source /home/olender/firedrakes/2024_09_11/firedrake/bin/activate
pytest --cov-report=xml --cov=spyro test/
- name: Running parallel 3D forward test
run: |
source /home/olender/firedrakes/2024_07_19/firedrake/bin/activate
source /home/olender/firedrakes/2024_09_11/firedrake/bin/activate
mpiexec -n 6 pytest test_3d/test_hexahedral_convergence.py
mpiexec -n 6 pytest test_parallel/test_forward.py
mpiexec -n 6 pytest test_parallel/test_fwi.py
- name: Covering parallel 3D forward test
continue-on-error: true
run: |
source /home/olender/firedrakes/2024_07_19/firedrake/bin/activate
source /home/olender/firedrakes/2024_09_11/firedrake/bin/activate
mpiexec -n 6 pytest --cov-report=xml --cov-append --cov=spyro test_3d/test_hexahedral_convergence.py
- name: Covering parallel forward test
continue-on-error: true
run: |
source /home/olender/firedrakes/2024_07_19/firedrake/bin/activate
source /home/olender/firedrakes/2024_09_11/firedrake/bin/activate
mpiexec -n 6 pytest --cov-report=xml --cov-append --cov=spyro test_parallel/test_forward.py
- name: Covering parallel fwi test
continue-on-error: true
run: |
source /home/olender/firedrakes/2024_07_19/firedrake/bin/activate
source /home/olender/firedrakes/2024_09_11/firedrake/bin/activate
mpiexec -n 6 pytest --cov-report=xml --cov-append --cov=spyro test_parallel/test_fwi.py
- name: Uploading coverage to Codecov
run: export CODECOV_TOKEN="6cd21147-54f7-4b77-94ad-4b138053401d" && bash <(curl -s https://codecov.io/bash)
Expand Down
File renamed without changes.
46 changes: 46 additions & 0 deletions demos/with_automatic_differentiation/run_forward_ad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import firedrake as fire
import spyro
from demos.with_automatic_differentiation.utils import \
model_settings, make_c_camembert
import os
os.environ["OMP_NUM_THREADS"] = "1"

# --- Basid setup to run a forward simulation with AD --- #

model = model_settings()

# Use emsemble parallelism.
M = model["parallelism"]["num_spacial_cores"]
my_ensemble = fire.Ensemble(fire.COMM_WORLD, M)
mesh = fire.UnitSquareMesh(50, 50, comm=my_ensemble.comm)
element = fire.FiniteElement(
model["opts"]["method"], mesh.ufl_cell(), degree=model["opts"]["degree"],
variant=model["opts"]["quadrature"]
)
V = fire.FunctionSpace(mesh, element)


forward_solver = spyro.solvers.forward_ad.ForwardSolver(model, mesh, V)

c_true = make_c_camembert(mesh, V)
# Ricker wavelet
wavelet = spyro.full_ricker_wavelet(
model["timeaxis"]["dt"], model["timeaxis"]["tf"],
model["acquisition"]["frequency"],
)

if model["parallelism"]["type"] is None:
outfile = fire.VTKFile("solution.pvd")
for sn in range(len(model["acquisition"]["source_pos"])):
rec_data, _ = forward_solver.execute(c_true, sn, wavelet)
sol = forward_solver.solution
outfile.write(sol)
else:
# source_number based on the ensemble.ensemble_comm.rank
source_number = my_ensemble.ensemble_comm.rank
rec_data, _ = forward_solver.execute_acoustic(
c_true, source_number, wavelet)
sol = forward_solver.solution
fire.VTKFile(
"solution_" + str(source_number) + ".pvd", comm=my_ensemble.comm
).write(sol)
112 changes: 112 additions & 0 deletions demos/with_automatic_differentiation/run_fwi_ad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import firedrake as fire
import firedrake.adjoint as fire_ad
from checkpoint_schedules import Revolve
import spyro
from demos.with_automatic_differentiation import utils
import os
os.environ["OMP_NUM_THREADS"] = "1"

# --- Basid setup to run a FWI --- #
model = utils.model_settings()


def forward(
c, compute_functional=False, true_data_receivers=None, annotate=False
):
"""Time-stepping acoustic forward solver.
The time integration is done using a central difference scheme.
Parameters
----------
c : firedrake.Function
Velocity field.
compute_functional : bool, optional
Whether to compute the functional. If True, the true receiver
data must be provided.
true_data_receivers : list, optional
True receiver data. This is used to compute the functional.
annotate : bool, optional
If True, the forward model is annotated for automatic differentiation.
Returns
-------
(receiver_data : list, J_val : float)
Receiver data and functional value.
"""
if annotate:
fire_ad.continue_annotation()
if model["aut_dif"]["checkpointing"]:
total_steps = int(model["timeaxis"]["tf"] / model["timeaxis"]["dt"])
steps_store = int(total_steps / 10) # Store 10% of the steps.
tape = fire_ad.get_working_tape()
tape.progress_bar = fire.ProgressBar
tape.enable_checkpointing(Revolve(total_steps, steps_store))

if model["parallelism"]["type"] is None:
outfile = fire.VTKFile("solution.pvd")
receiver_data = []
J = 0.0
for sn in range(len(model["acquisition"]["source_pos"])):
rec_data, J_val = forward_solver.execute_acoustic(c, sn, wavelet)
receiver_data.append(rec_data)
J += J_val
sol = forward_solver.solution
outfile.write(sol)

else:
# source_number based on the ensemble.ensemble_comm.rank
source_number = my_ensemble.ensemble_comm.rank
receiver_data, J = forward_solver.execute_acoustic(
c, source_number, wavelet,
compute_functional=compute_functional,
true_data_receivers=true_data_receivers
)
sol = forward_solver.solution
fire.VTKFile(
"solution_" + str(source_number) + ".pvd", comm=my_ensemble.comm
).write(sol)

return receiver_data, J


# Use emsemble parallelism.
M = model["parallelism"]["num_spacial_cores"]
my_ensemble = fire.Ensemble(fire.COMM_WORLD, M)
mesh = fire.UnitSquareMesh(50, 50, comm=my_ensemble.comm)
element = fire.FiniteElement(
model["opts"]["method"], mesh.ufl_cell(), degree=model["opts"]["degree"],
variant=model["opts"]["quadrature"]
)
V = fire.FunctionSpace(mesh, element)


forward_solver = spyro.solvers.forward_ad.ForwardSolver(model, mesh, V)
# Camembert model.
c_true = utils.make_c_camembert(mesh, V)
# Ricker wavelet
wavelet = spyro.full_ricker_wavelet(
model["timeaxis"]["dt"], model["timeaxis"]["tf"],
model["acquisition"]["frequency"],
)

true_rec, _ = forward(c_true)

# --- FWI with AD --- #
c_guess = utils.make_c_camembert(mesh, V, c_guess=True)
guess_rec, J = forward(
c_guess, compute_functional=True, true_data_receivers=true_rec,
annotate=True
)

# :class:`~.EnsembleReducedFunctional` is employed to recompute in
# parallel the functional and its gradient associated with the multiple sources
# (3 in this case).
J_hat = fire_ad.EnsembleReducedFunctional(
J, fire_ad.Control(c_guess), my_ensemble)
c_optimised = fire_ad.minimize(J_hat, method="L-BFGS-B",
options={"disp": True, "maxiter": 10},
bounds=(1.5, 3.5),
derivative_options={"riesz_representation": 'l2'})

fire.VTKFile("c_optimised.pvd").write(c_optimised)
104 changes: 104 additions & 0 deletions demos/with_automatic_differentiation/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# --- Basid setup to run a forward simulation with AD --- #
import firedrake as fire
import spyro

def model_settings():
"""Model settings for forward and Full Waveform Inversion (FWI)
simulations.
Returns
-------
model : dict
Dictionary containing the model settings.
"""

model = {}

model["opts"] = {
"method": "KMV", # either CG or mass_lumped_triangle
"quadrature": "KMV", # Equi or mass_lumped_triangle
"degree": 1, # p order
"dimension": 2, # dimension
"regularization": False, # regularization is on?
"gamma": 1e-5, # regularization parameter
}

model["parallelism"] = {
# options:
# `shots_parallelism`. Shots parallelism.
# None - no shots parallelism.
"type": "shots_parallelism",
"num_spacial_cores": 1, # Number of cores to use in the spatial
# parallelism.
}

# Define the domain size without the ABL.
model["mesh"] = {
"Lz": 1.0, # depth in km - always positive
"Lx": 1.0, # width in km - always positive
"Ly": 0.0, # thickness in km - always positive
"meshfile": "not_used.msh",
"initmodel": "not_used.hdf5",
"truemodel": "not_used.hdf5",
}

# Specify a 250-m Absorbing Boundary Layer (ABL) on the three sides of the domain to damp outgoing waves.
model["BCs"] = {
"status": False, # True or False, used to turn on any type of BC
"outer_bc": "non-reflective", # none or non-reflective (outer boundary condition)
"abl_bc": "none", # none, gaussian-taper, or alid
"lz": 0.0, # thickness of the ABL in the z-direction (km) - always positive
"lx": 0.0, # thickness of the ABL in the x-direction (km) - always positive
"ly": 0.0, # thickness of the ABL in the y-direction (km) - always positive
}

model["acquisition"] = {
"source_type": "Ricker",
"source_pos": spyro.create_transect((0.2, 0.15), (0.8, 0.15), 3),
"frequency": 7.0,
"delay": 1.0,
"receiver_locations": spyro.create_transect((0.2, 0.2), (0.8, 0.2), 10),
}
model["aut_dif"] = {
"status": True,
"checkpointing": True,
}

model["timeaxis"] = {
"t0": 0.0, # Initial time for event
"tf": 0.8, # Final time for event (for test 7)
"dt": 0.001, # timestep size (divided by 2 in the test 4. dt for test 3 is 0.00050)
"amplitude": 1, # the Ricker has an amplitude of 1.
"nspool": 20, # (20 for dt=0.00050) how frequently to output solution to pvds
"fspool": 1, # how frequently to save solution to RAM
}

return model


def make_c_camembert(mesh, function_space, c_guess=False, plot_c=False):
"""Acoustic velocity model.
Parameters
----------
mesh : firedrake.Mesh
Mesh.
function_space : firedrake.FunctionSpace
Function space.
c_guess : bool, optional
If True, the initial guess for the velocity field is returned.
plot_c : bool, optional
If True, the velocity field is saved to a VTK file.
"""
x, z = fire.SpatialCoordinate(mesh)
if c_guess:
c = fire.Function(function_space).interpolate(1.5 + 0.0 * x)
else:
c = fire.Function(function_space).interpolate(
2.5
+ 1 * fire.tanh(100 * (0.125 - fire.sqrt((x - 0.5) ** 2 + (z - 0.5) ** 2)))
)
if plot_c:
outfile = fire.VTKFile("acoustic_cp.pvd")
outfile.write(c)
return c
2 changes: 2 additions & 0 deletions spyro/solvers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from .acoustic_wave import AcousticWave
from .mms_acoustic import AcousticWaveMMS
from .inversion import FullWaveformInversion
from .forward_ad import ForwardSolver

__all__ = [
"Wave",
"AcousticWave",
"AcousticWaveMMS",
"FullWaveformInversion",
"ForwardSolver",
]
Loading

0 comments on commit 42b3ad3

Please sign in to comment.