From the plot above, we can notice two things. First, it is rotated by 90 degrees. In cases like this, the saved PNG image will be in the correct orientation. However, the plotted image in this notebook is rotated because the velocity model we used, commonly with segy files, has the Z-axis before the X-axis in data input.
-
Another observation is that, unlike in the simple forward tutorial, we looked at the experiment layout after running the forward solve. For memory purposes, a 2D or 3D velocity file is only actually interpolated into our domain when it is necessary for another method of the Wave object class. If you need to force the interpolation sooner, call the _get_initial_velocity_model() method.
+
Another observation is that, unlike in the simple forward tutorial, we looked at the experiment layout after running the forward solve. For memory purposes, a 2D or 3D velocity file is only actually interpolated into our domain when it is necessary for another method of the Wave object class. If you need to force the interpolation sooner, call the _initialize_model_parameters() method.
It is also important to note that even though receivers look like a line, they are actually located in points, which can be visible by zooming into the image, not coinciding with nodes.
diff --git a/notebook_tutorials/simple_forward_with_overthrust.ipynb b/notebook_tutorials/simple_forward_with_overthrust.ipynb
index f51e3c50..69d76d55 100644
--- a/notebook_tutorials/simple_forward_with_overthrust.ipynb
+++ b/notebook_tutorials/simple_forward_with_overthrust.ipynb
@@ -397,7 +397,7 @@
"source": [
"From the plot above, we can notice two things. First, it is rotated by 90 degrees. In cases like this, the saved PNG image will be in the correct orientation. However, the plotted image in this notebook is rotated because the velocity model we used, commonly with segy files, has the Z-axis before the X-axis in data input.\n",
"\n",
- "Another observation is that, unlike in the **simple forward tutorial**, we looked at the experiment layout after running the forward solve. For memory purposes, a 2D or 3D velocity file is only actually interpolated into our domain when it is necessary for another method of the Wave object class. If you need to force the interpolation sooner, call the _get_initial_velocity_model() method.\n",
+ "Another observation is that, unlike in the **simple forward tutorial**, we looked at the experiment layout after running the forward solve. For memory purposes, a 2D or 3D velocity file is only actually interpolated into our domain when it is necessary for another method of the Wave object class. If you need to force the interpolation sooner, call the _initialize_model_parameters() method.\n",
"\n",
"It is also important to note that even though receivers look like a line, they are actually located in points, which can be visible by zooming into the image, not coinciding with nodes."
]
diff --git a/spyro/io/model_parameters.py b/spyro/io/model_parameters.py
index 95377d31..89e86037 100644
--- a/spyro/io/model_parameters.py
+++ b/spyro/io/model_parameters.py
@@ -608,9 +608,9 @@ def _sanitize_optimization_and_velocity(self):
if "velocity_conditional" not in dictionary["synthetic_data"]:
self.velocity_model_type = None
warnings.warn(
- "No velocity model set initially. If using \
- user defined conditional or expression, please \
- input it in the Wave object."
+ "No velocity model set initially. If using " \
+ "user defined conditional or expression, please " \
+ "input it in the Wave object."
)
if "velocity_conditional" in dictionary["synthetic_data"]:
diff --git a/spyro/solvers/acoustic_wave.py b/spyro/solvers/acoustic_wave.py
index 1fb5a373..5c6a30c0 100644
--- a/spyro/solvers/acoustic_wave.py
+++ b/spyro/solvers/acoustic_wave.py
@@ -40,7 +40,7 @@ def forward_solve(self):
if self.function_space is None:
self.force_rebuild_function_space()
- self._get_initial_velocity_model()
+ self._initialize_model_parameters()
self.c = self.initial_velocity_model
self.matrix_building()
self.wave_propagator()
@@ -144,3 +144,33 @@ def reset_pressure(self):
self.u_n.assign(0.0)
except:
warnings.warn("No pressure to reset")
+
+ #@override
+ def _initialize_model_parameters(self):
+ if self.initial_velocity_model is not None:
+ return None
+
+ if self.initial_velocity_model_file is None:
+ raise ValueError("No velocity model or velocity file to load.")
+
+ if self.initial_velocity_model_file.endswith(".segy"):
+ vp_filename, vp_filetype = os.path.splitext(
+ self.initial_velocity_model_file
+ )
+ warnings.warn("Converting segy file to hdf5")
+ write_velocity_model(
+ self.initial_velocity_model_file, ofname=vp_filename
+ )
+ self.initial_velocity_model_file = vp_filename + ".hdf5"
+
+ if self.initial_velocity_model_file.endswith((".hdf5", ".h5")):
+ self.initial_velocity_model = interpolate(
+ self,
+ self.initial_velocity_model_file,
+ self.function_space.sub(0),
+ )
+
+ if self.debug_output:
+ fire.File("initial_velocity_model.pvd").write(
+ self.initial_velocity_model, name="velocity"
+ )
\ No newline at end of file
diff --git a/spyro/solvers/elastic_wave/elastic_wave.py b/spyro/solvers/elastic_wave/elastic_wave.py
index e69de29b..d18d2c0e 100644
--- a/spyro/solvers/elastic_wave/elastic_wave.py
+++ b/spyro/solvers/elastic_wave/elastic_wave.py
@@ -0,0 +1,29 @@
+from abc import abstractmethod
+
+from ..wave import Wave
+
+class ElasticWave(Wave):
+ '''Base class for elastic wave propagators'''
+ def __init__(self, dictionary, comm=None):
+ super().__init__(dictionary, comm=comm)
+
+ #@override
+ def _initialize_model_parameters(self):
+ d = self.input_dictionary.get("synthetic_data", False)
+ if bool(d) and "type" in d:
+ if d["type"] == "object":
+ self.initialize_model_parameters_from_object(d)
+ elif d["type"] == "file":
+ self.initialize_model_parameters_from_file(d)
+ else:
+ raise Exception(f"Invalid synthetic data type: {d['type']}")
+ else:
+ raise Exception("Input dictionary must contain ['synthetic_data']['type']")
+
+ @abstractmethod
+ def initialize_model_parameters_from_object(self, synthetic_data_dict):
+ pass
+
+ @abstractmethod
+ def initialize_model_parameters_from_file(self, synthetic_data_dict):
+ pass
diff --git a/spyro/solvers/elastic_wave/isotropic_wave.py b/spyro/solvers/elastic_wave/isotropic_wave.py
new file mode 100644
index 00000000..52caf4e7
--- /dev/null
+++ b/spyro/solvers/elastic_wave/isotropic_wave.py
@@ -0,0 +1,48 @@
+from .elastic_wave import ElasticWave
+
+class IsotropicWave(ElasticWave):
+ '''Isotropic elastic wave propagator'''
+ def __init__(self, dictionary, comm=None):
+ super().__init__(dictionary, comm=comm)
+
+ self.rho = None # Density
+ self.lmbda = None # First Lame parameter
+ self.mu = None # Second Lame parameter
+ self.c_s = None # Secondary wave velocity
+
+ #@override
+ def initialize_model_parameters_from_object(self, synthetic_data_dict: dict):
+ self.rho = synthetic_data_dict.get("density", None)
+ self.lmbda = synthetic_data_dict.get("lambda",
+ synthetic_data_dict.get("lame_first", None))
+ self.mu = synthetic_data_dict.get("mu",
+ synthetic_data_dict.get("lame_second", None))
+ self.c = synthetic_data_dict.get("p_wave_velocity", None)
+ self.c_s = synthetic_data_dict.get("s_wave_velocity", None)
+
+ # Check if {rho, lambda, mu} is set and {c, c_s} are not
+ option_1 = bool(self.rho) and \
+ bool(self.lmbda) and \
+ bool(self.mu) and \
+ not bool(self.c) and \
+ not bool(self.c_s)
+ # Check if {rho, c, c_s} is set and {lambda, mu} are not
+ option_2 = bool(self.rho) and \
+ bool(self.c) and \
+ bool(self.c_s) and \
+ not bool(self.lmbda) and \
+ not bool(self.mu)
+
+ if not option_1 and not option_2:
+ raise Exception(f"Inconsistent selection of isotropic elastic wave parameters:\n" \
+ f" Density : {bool(self.rho)}\n"\
+ f" Lame first : {bool(self.lmbda)}\n"\
+ f" Lame second : {bool(self.mu)}\n"\
+ f" P-wave velocity: {bool(self.c)}\n"\
+ f" S-wave velocity: {bool(self.c_s)}\n"\
+ "The valid options are \{Density, Lame first, Lame second\} "\
+ "or \{Density, P-wave velocity, S-wave velocity\}")
+
+ #@override
+ def initialize_model_parameters_from_file(self, synthetic_data_dict):
+ raise NotImplementedError
diff --git a/spyro/solvers/time_integration_central_difference.py b/spyro/solvers/time_integration_central_difference.py
index 410bf5cd..603f92c5 100644
--- a/spyro/solvers/time_integration_central_difference.py
+++ b/spyro/solvers/time_integration_central_difference.py
@@ -79,8 +79,8 @@ def central_difference(Wave_object, source_id=0):
if (step - 1) % Wave_object.output_frequency == 0:
assert (
fire.norm(u_n) < 1
- ), "Numerical instability. Try reducing dt or building the \
- mesh differently"
+ ), "Numerical instability. Try reducing dt or building the " \
+ "mesh differently"
if Wave_object.forward_output:
output.write(u_n, time=t, name="Pressure")
@@ -184,8 +184,8 @@ def mixed_space_central_difference(Wave_object, source_id=0):
if (step - 1) % Wave_object.output_frequency == 0:
assert (
fire.norm(X_np1.sub(0)) < 1
- ), "Numerical instability. Try reducing dt or building the \
- mesh differently"
+ ), "Numerical instability. Try reducing dt or building the " \
+ "mesh differently"
if Wave_object.forward_output:
output.write(X_np1.sub(0), time=t, name="Pressure")
@@ -292,8 +292,8 @@ def central_difference_MMS(Wave_object, source_id=0):
if (step - 1) % Wave_object.output_frequency == 0:
assert (
fire.norm(u_n) < 1
- ), "Numerical instability. Try reducing dt or building the \
- mesh differently"
+ ), "Numerical instability. Try reducing dt or building the " \
+ "mesh differently"
if Wave_object.forward_output:
output.write(u_n, time=t, name="Pressure")
if t > 0:
diff --git a/spyro/solvers/wave.py b/spyro/solvers/wave.py
index 4351d17c..be2baf5e 100644
--- a/spyro/solvers/wave.py
+++ b/spyro/solvers/wave.py
@@ -209,8 +209,8 @@ def set_initial_velocity_model(
self.initial_velocity_model = vp
else:
raise ValueError(
- "Please specify either a conditional, expression, firedrake \
- function or new file name (segy or hdf5)."
+ "Please specify either a conditional, expression, firedrake " \
+ "function or new file name (segy or hdf5)."
)
if output:
fire.File("initial_velocity_model.pvd").write(
@@ -223,35 +223,10 @@ def _map_sources_and_receivers(self):
else:
self.sources = None
self.receivers = Receivers(self)
-
- def _get_initial_velocity_model(self):
- if self.initial_velocity_model is not None:
- return None
-
- if self.initial_velocity_model_file is None:
- raise ValueError("No velocity model or velocity file to load.")
-
- if self.initial_velocity_model_file.endswith(".segy"):
- vp_filename, vp_filetype = os.path.splitext(
- self.initial_velocity_model_file
- )
- warnings.warn("Converting segy file to hdf5")
- write_velocity_model(
- self.initial_velocity_model_file, ofname=vp_filename
- )
- self.initial_velocity_model_file = vp_filename + ".hdf5"
-
- if self.initial_velocity_model_file.endswith(".hdf5"):
- self.initial_velocity_model = interpolate(
- self,
- self.initial_velocity_model_file,
- self.function_space.sub(0),
- )
-
- if self.debug_output:
- fire.File("initial_velocity_model.pvd").write(
- self.initial_velocity_model, name="velocity"
- )
+
+ @abstractmethod
+ def _initialize_model_parameters(self):
+ pass
def _build_function_space(self):
self.function_space = FE_method(self.mesh, self.method, self.degree)
diff --git a/spyro/tools/cells_per_wavelength_calculator.py b/spyro/tools/cells_per_wavelength_calculator.py
index 8c92391c..af6e7895 100644
--- a/spyro/tools/cells_per_wavelength_calculator.py
+++ b/spyro/tools/cells_per_wavelength_calculator.py
@@ -347,7 +347,7 @@ def find_minimum(self, starting_cpw=None, TOL=None, accuracy=None, savetxt=False
# Running forward model
Wave_obj = self.build_current_object(cpw)
- Wave_obj._get_initial_velocity_model()
+ Wave_obj._initialize_model_parameters() # TO REVIEW: call to protected method
# Setting up time-step
if self.timestep_calculation != "float":
diff --git a/test/test_isotropic_wave.py b/test/test_isotropic_wave.py
new file mode 100644
index 00000000..ed5a3ccb
--- /dev/null
+++ b/test/test_isotropic_wave.py
@@ -0,0 +1,67 @@
+import pytest
+
+from spyro.solvers.elastic_wave.isotropic_wave import IsotropicWave
+
+# TO REVIEW: it is extra work to have to define this dictionary everytime
+# Here I listed only the required parameters for running to get a view of
+# what is currently necessary. Note that the dictionary is not even complete
+dummy_dict = {
+ "options": {
+ "cell_type": "T",
+ "variant": "lumped",
+ },
+ "time_axis": {
+ "final_time": 1,
+ "dt": 0.001,
+ "output_frequency": 100,
+ "gradient_sampling_frequency": 1,
+ },
+ "mesh": {},
+ "acquisition": {
+ "receiver_locations": [],
+ "source_type": "ricker",
+ "source_locations": [(0, 0)],
+ "frequency": 5.0,
+ },
+}
+
+def test_initialize_model_parameters_from_object_missing_parameters():
+ synthetic_dict = {
+ "type": "object",
+ }
+ wave = IsotropicWave(dummy_dict)
+ with pytest.raises(Exception) as e:
+ wave.initialize_model_parameters_from_object(synthetic_dict)
+
+def test_initialize_model_parameters_from_object_first_option():
+ synthetic_dict = {
+ "type": "object",
+ "density": 1,
+ "lambda": 2,
+ "mu": 3,
+ }
+ wave = IsotropicWave(dummy_dict)
+ wave.initialize_model_parameters_from_object(synthetic_dict)
+
+def test_initialize_model_parameters_from_object_second_option():
+ synthetic_dict = {
+ "type": "object",
+ "density": 1,
+ "p_wave_velocity": 2,
+ "s_wave_velocity": 3,
+ }
+ wave = IsotropicWave(dummy_dict)
+ wave.initialize_model_parameters_from_object(synthetic_dict)
+
+def test_initialize_model_parameters_from_object_redundant():
+ synthetic_dict = {
+ "type": "object",
+ "density": 1,
+ "lmbda": 2,
+ "mu": 3,
+ "p_wave_velocity": 2,
+ "s_wave_velocity": 3,
+ }
+ wave = IsotropicWave(dummy_dict)
+ with pytest.raises(Exception) as e:
+ wave.initialize_model_parameters_from_object(synthetic_dict)
\ No newline at end of file
From 9f3edbd8f32e2c671785a2a437f940cb9523b596 Mon Sep 17 00:00:00 2001
From: Eduardo Moscatelli de Souza <5752216+SouzaEM@users.noreply.github.com>
Date: Tue, 2 Jul 2024 10:36:17 -0300
Subject: [PATCH 3/5] Refactor time integration code to merge
'central_difference' and 'mixed_space_central_difference' function into a
unique function
---
spyro/examples/camembert.py | 4 +
spyro/examples/rectangle.py | 25 +--
spyro/solvers/acoustic_wave.py | 69 +++++-
spyro/solvers/dg_wave.py | 63 ------
spyro/solvers/forward.py | 181 ---------------
spyro/solvers/forward_AD.py | 180 ---------------
spyro/solvers/gradient_old.py | 210 ------------------
spyro/solvers/time_integration.py | 9 +-
.../time_integration_central_difference.py | 202 ++++-------------
spyro/solvers/wave.py | 48 ++++
spyro/utils/typing.py | 6 +
11 files changed, 169 insertions(+), 828 deletions(-)
delete mode 100644 spyro/solvers/dg_wave.py
delete mode 100644 spyro/solvers/forward.py
delete mode 100644 spyro/solvers/forward_AD.py
delete mode 100644 spyro/solvers/gradient_old.py
create mode 100644 spyro/utils/typing.py
diff --git a/spyro/examples/camembert.py b/spyro/examples/camembert.py
index 2abb71b0..0fc6b69b 100644
--- a/spyro/examples/camembert.py
+++ b/spyro/examples/camembert.py
@@ -156,3 +156,7 @@ def _camembert_velocity_model(self):
)
self.set_initial_velocity_model(conditional=cond, dg_velocity_model=False)
return None
+
+if __name__ == "__main__":
+ wave = Camembert_acoustic()
+ wave.forward_solve()
diff --git a/spyro/examples/rectangle.py b/spyro/examples/rectangle.py
index fc82c892..ec87cd0a 100644
--- a/spyro/examples/rectangle.py
+++ b/spyro/examples/rectangle.py
@@ -188,25 +188,6 @@ def multiple_layer_velocity_model(self, z_switch, layers):
# cond = fire.conditional(self.mesh_z > z_switch, layer1, layer2)
self.set_initial_velocity_model(conditional=cond)
-
-# class Rectangle(AcousticWave):
-# def __init__(self, model_dictionary=None, comm=None):
-# model_parameters = Rectangle_parameters(
-# dictionary=model_dictionary, comm=comm
-# )
-# super().__init__(
-# model_parameters=model_parameters, comm=model_parameters.comm
-# )
-# comm = self.comm
-# num_sources = self.number_of_sources
-# if comm.comm.rank == 0 and comm.ensemble_comm.rank == 0:
-# print(
-# "INFO: Distributing %d shot(s) across %d core(s). \
-# Each shot is using %d cores"
-# % (
-# num_sources,
-# fire.COMM_WORLD.size,
-# fire.COMM_WORLD.size / comm.ensemble_comm.size,
-# ),
-# flush=True,
-# )
+if __name__ == "__main__":
+ wave = Rectangle_acoustic()
+ wave.forward_solve()
diff --git a/spyro/solvers/acoustic_wave.py b/spyro/solvers/acoustic_wave.py
index 5c6a30c0..18f867fd 100644
--- a/spyro/solvers/acoustic_wave.py
+++ b/spyro/solvers/acoustic_wave.py
@@ -14,7 +14,7 @@
from .backward_time_integration import (
backward_wave_propagator,
)
-
+from ..utils.typing import override
class AcousticWave(Wave):
def save_current_velocity_model(self, file_name=None):
@@ -68,6 +68,7 @@ def matrix_building(self):
self.trial_function = None
self.u_nm1 = None
self.u_n = None
+ self.u_np1 = fire.Function(self.function_space)
self.lhs = None
self.solver = None
self.rhs = None
@@ -81,6 +82,7 @@ def matrix_building(self):
self.X = None
self.X_n = None
self.X_nm1 = None
+ self.X_np1 = fire.Function(V * Z)
construct_solver_or_matrix_with_pml(self)
@ensemble_propagator
@@ -145,7 +147,7 @@ def reset_pressure(self):
except:
warnings.warn("No pressure to reset")
- #@override
+ @override
def _initialize_model_parameters(self):
if self.initial_velocity_model is not None:
return None
@@ -173,4 +175,65 @@ def _initialize_model_parameters(self):
if self.debug_output:
fire.File("initial_velocity_model.pvd").write(
self.initial_velocity_model, name="velocity"
- )
\ No newline at end of file
+ )
+
+ @override
+ def _set_vstate(self, vstate):
+ if self.abc_boundary_layer_type == "PML":
+ self.X_n.assign(vstate)
+ else:
+ self.u_n.assign(vstate)
+
+ @override
+ def _get_vstate(self):
+ if self.abc_boundary_layer_type == "PML":
+ return self.X_n
+ else:
+ return self.u_n
+
+ @override
+ def _set_prev_vstate(self, vstate):
+ if self.abc_boundary_layer_type == "PML":
+ self.X_nm1.assign(vstate)
+ else:
+ self.u_nm1.assign(vstate)
+
+ @override
+ def _get_prev_vstate(self):
+ if self.abc_boundary_layer_type == "PML":
+ return self.X_nm1
+ else:
+ return self.u_nm1
+
+ @override
+ def _set_next_vstate(self, vstate):
+ if self.abc_boundary_layer_type == "PML":
+ self.X_np1.assign(vstate)
+ else:
+ self.u_np1.assign(vstate)
+
+ @override
+ def _get_next_vstate(self):
+ if self.abc_boundary_layer_type == "PML":
+ return self.X_np1
+ else:
+ return self.u_np1
+
+ @override
+ def get_receivers_output(self):
+ if self.abc_boundary_layer_type == "PML":
+ data_with_halos = self.X_n.dat.data_ro_with_halos[0][:]
+ else:
+ data_with_halos = self.u_n.dat.data_ro_with_halos[:]
+ return self.receivers.interpolate(data_with_halos)
+
+ @override
+ def get_function(self):
+ if self.abc_boundary_layer_type == "PML":
+ return self.X_n.sub(0)
+ else:
+ return self.u_n
+
+ @override
+ def get_function_name(self):
+ return "Pressure"
\ No newline at end of file
diff --git a/spyro/solvers/dg_wave.py b/spyro/solvers/dg_wave.py
deleted file mode 100644
index 0abe5e89..00000000
--- a/spyro/solvers/dg_wave.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# import firedrake as fire
-# from firedrake import dot, grad, jump, avg, dx, ds, dS, Constant
-# from spyro import Wave
-
-# fire.set_log_level(fire.ERROR)
-
-
-# class DG_Wave(Wave):
-# def matrix_building(self):
-# """Builds solver operators. Doesn't create mass matrices if
-# matrix_free option is on,
-# which it is by default.
-# """
-# V = self.function_space
-# # Trial and test functions
-# u = fire.TrialFunction(V)
-# v = fire.TestFunction(V)
-
-# # # Previous functions for time integration
-# u_n = fire.Function(V)
-# u_nm1 = fire.Function(V)
-# self.u_nm1 = u_nm1
-# self.u_n = u_n
-# c = self.c
-
-# self.current_time = 0.0
-# dt = self.dt
-
-# # Normal component, cell size and right-hand side
-# h = fire.CellDiameter(self.mesh)
-# h_avg = (h("+") + h("-")) / 2
-# n = fire.FacetNormal(self.mesh)
-
-# # Parameters
-# alpha = 4.0
-# gamma = 8.0
-
-# # Bilinear form
-# a = (
-# dot(grad(v), grad(u)) * dx
-# - dot(avg(grad(v)), jump(u, n)) * dS
-# - dot(jump(v, n), avg(grad(u))) * dS
-# + alpha / h_avg * dot(jump(v, n), jump(u, n)) * dS
-# - dot(grad(v), u * n) * ds
-# - dot(v * n, grad(u)) * ds
-# + (gamma / h) * v * u * ds
-# + ((u) / Constant(dt**2)) / c * v * dx
-# )
-# # Linear form
-# b = ((2.0 * u_n - u_nm1) / Constant(dt**2)) / c * v * dx
-# form = a - b
-
-# lhs = fire.lhs(form)
-# rhs = fire.rhs(form)
-
-# A = fire.assemble(lhs)
-# params = {"ksp_type": "gmres"}
-# self.solver = fire.LinearSolver(A, solver_parameters=params)
-
-# # lterar para como o thiago fez
-# self.rhs = rhs
-# B = fire.Function(V)
-# self.B = B
diff --git a/spyro/solvers/forward.py b/spyro/solvers/forward.py
deleted file mode 100644
index 06d38e85..00000000
--- a/spyro/solvers/forward.py
+++ /dev/null
@@ -1,181 +0,0 @@
-# import firedrake as fire
-# from firedrake.assemble import create_assembly_callable
-# from firedrake import Constant, dx, dot, inner, grad, ds
-# import FIAT
-# import finat
-
-# from .. import utils
-# from ..domains import quadrature, space
-# from ..pml import damping
-# from ..io import ensemble_forward
-# from . import helpers
-
-
-# def gauss_lobatto_legendre_line_rule(degree):
-# fiat_make_rule = FIAT.quadrature.GaussLobattoLegendreQuadratureLineRule
-# fiat_rule = fiat_make_rule(FIAT.ufc_simplex(1), degree + 1)
-# finat_ps = finat.point_set.GaussLobattoLegendrePointSet
-# finat_qr = finat.quadrature.QuadratureRule
-# return finat_qr(finat_ps(fiat_rule.get_points()), fiat_rule.get_weights())
-
-
-# # 3D
-# def gauss_lobatto_legendre_cube_rule(dimension, degree):
-# """Returns GLL integration rule
-
-# Parameters
-# ----------
-# dimension: `int`
-# The dimension of the mesh
-# degree: `int`
-# The degree of the function space
-
-# Returns
-# -------
-# result: `finat.quadrature.QuadratureRule`
-# The GLL integration rule
-# """
-# make_tensor_rule = finat.quadrature.TensorProductQuadratureRule
-# result = gauss_lobatto_legendre_line_rule(degree)
-# for _ in range(1, dimension):
-# line_rule = gauss_lobatto_legendre_line_rule(degree)
-# result = make_tensor_rule([result, line_rule])
-# return result
-
-
-# @ensemble_forward
-# def forward(
-# model,
-# mesh,
-# comm,
-# c,
-# excitations,
-# wavelet,
-# receivers,
-# source_num=0,
-# output=False,
-# ):
-# """Secord-order in time fully-explicit scheme
-# with implementation of a Perfectly Matched Layer (PML) using
-# CG FEM with or without higher order mass lumping (KMV type elements).
-
-# Parameters
-# ----------
-# model: Python `dictionary`
-# Contains model options and parameters
-# mesh: Firedrake.mesh object
-# The 2D/3D triangular mesh
-# comm: Firedrake.ensemble_communicator
-# The MPI communicator for parallelism
-# c: Firedrake.Function
-# The velocity model interpolated onto the mesh.
-# excitations: A list Firedrake.Functions
-# wavelet: array-like
-# Time series data that's injected at the source location.
-# receivers: A :class:`spyro.Receivers` object.
-# Contains the receiver locations and sparse interpolation methods.
-# source_num: `int`, optional
-# The source number you wish to simulate
-# output: `boolean`, optional
-# Whether or not to write results to pvd files.
-
-# Returns
-# -------
-# usol: list of Firedrake.Functions
-# The full field solution at `fspool` timesteps
-# usol_recv: array-like
-# The solution interpolated to the receivers at all timesteps
-
-# """
-
-# method = model["opts"]["method"]
-# degree = model["opts"]["degree"]
-# dt = model["timeaxis"]["dt"]
-# final_time = model["timeaxis"]["tf"]
-# nspool = model["timeaxis"]["nspool"]
-# fspool = model["timeaxis"]["fspool"]
-# excitations.current_source = source_num
-
-# nt = int(final_time / dt) # number of timesteps
-
-# element = fire.FiniteElement(method, mesh.ufl_cell(), degree=degree)
-
-# V = fire.FunctionSpace(mesh, element)
-
-# # typical CG FEM in 2d/3d
-# u = fire.TrialFunction(V)
-# v = fire.TestFunction(V)
-
-# u_nm1 = fire.Function(V)
-# u_n = fire.Function(V, name="pressure")
-# u_np1 = fire.Function(V)
-
-# if output:
-# outfile = helpers.create_output_file("forward.pvd", comm, source_num)
-
-# t = 0.0
-
-# # -------------------------------------------------------
-# m1 = ((u) / Constant(dt**2)) * v * dx
-# a = (
-# c * c * dot(grad(u_n), grad(v)) * dx
-# + ((-2.0 * u_n + u_nm1) / Constant(dt**2)) * v * dx
-# ) # explicit
-
-# X = fire.Function(V)
-# B = fire.Function(V)
-
-# lhs = m1
-# rhs = -a
-
-# A = fire.assemble(lhs)
-# solver = fire.LinearSolver(A)
-
-# usol = [
-# fire.Function(V, name="pressure") for t in range(nt) if t % fspool == 0
-# ]
-# usol_recv = []
-# save_step = 0
-
-# assembly_callable = create_assembly_callable(rhs, tensor=B)
-
-# rhs_forcing = fire.Function(V)
-
-# for step in range(nt):
-# rhs_forcing.assign(0.0)
-# assembly_callable()
-# f = excitations.apply_source(rhs_forcing, wavelet[step])
-# B0 = B.sub(0)
-# B0 += f
-# solver.solve(X, B)
-
-# u_np1.assign(X)
-
-# usol_recv.append(
-# receivers.interpolate(u_np1.dat.data_ro_with_halos[:])
-# )
-
-# if step % fspool == 0:
-# usol[save_step].assign(u_np1)
-# save_step += 1
-
-# if step % nspool == 0:
-# assert (
-# fire.norm(u_n) < 1
-# ), "Numerical instability. Try reducing dt or building the mesh differently"
-# if output:
-# outfile.write(u_n, time=t, name="Pressure")
-# if t > 0:
-# helpers.display_progress(comm, t)
-
-# u_nm1.assign(u_n)
-# u_n.assign(u_np1)
-
-# t = step * float(dt)
-
-# usol_recv = helpers.fill(
-# usol_recv, receivers.is_local, nt, receivers.num_receivers
-# )
-# usol_recv = utils.communicate(usol_recv, comm)
-
-# return usol, usol_recv
diff --git a/spyro/solvers/forward_AD.py b/spyro/solvers/forward_AD.py
deleted file mode 100644
index a8e257f3..00000000
--- a/spyro/solvers/forward_AD.py
+++ /dev/null
@@ -1,180 +0,0 @@
-# from firedrake import *
-
-# # from .. import utils
-# from ..domains import quadrature, space
-
-# # from ..pml import damping
-# # from ..io import ensemble_forward
-# from . import helpers
-
-# # Note this turns off non-fatal warnings
-# set_log_level(ERROR)
-
-
-# # @ensemble_forward
-# def forward(
-# model,
-# mesh,
-# comm,
-# c,
-# excitations,
-# wavelet,
-# receivers,
-# source_num=0,
-# output=False,
-# **kwargs
-# ):
-# """Secord-order in time fully-explicit scheme
-# with implementation of a Perfectly Matched Layer (PML) using
-# CG FEM with or without higher order mass lumping (KMV type elements).
-
-# Parameters
-# ----------
-# model: Python `dictionary`
-# Contains model options and parameters
-# mesh: Firedrake.mesh object
-# The 2D/3D triangular mesh
-# comm: Firedrake.ensemble_communicator
-# The MPI communicator for parallelism
-# c: Firedrake.Function
-# The velocity model interpolated onto the mesh.
-# excitations: A list Firedrake.Functions
-# wavelet: array-like
-# Time series data that's injected at the source location.
-# receivers: A :class:`spyro.Receivers` object.
-# Contains the receiver locations and sparse interpolation methods.
-# source_num: `int`, optional
-# The source number you wish to simulate
-# output: `boolean`, optional
-# Whether or not to write results to pvd files.
-
-# Returns
-# -------
-# usol: list of Firedrake.Functions
-# The full field solution at `fspool` timesteps
-# usol_recv: array-like
-# The solution interpolated to the receivers at all timesteps
-
-# """
-
-# method = model["opts"]["method"]
-# degree = model["opts"]["degree"]
-# dim = model["opts"]["dimension"]
-# dt = model["timeaxis"]["dt"]
-# tf = model["timeaxis"]["tf"]
-# nspool = model["timeaxis"]["nspool"]
-# nt = int(tf / dt) # number of timesteps
-# excitations.current_source = source_num
-# params = set_params(method)
-# element = space.FE_method(mesh, method, degree)
-
-# V = FunctionSpace(mesh, element)
-
-# qr_x, qr_s, _ = quadrature.quadrature_rules(V)
-
-# if dim == 2:
-# z, x = SpatialCoordinate(mesh)
-# elif dim == 3:
-# z, x, y = SpatialCoordinate(mesh)
-
-# u = TrialFunction(V)
-# v = TestFunction(V)
-
-# u_nm1 = Function(V)
-# u_n = Function(V)
-# u_np1 = Function(V)
-
-# if output:
-# outfile = helpers.create_output_file("forward.pvd", comm, source_num)
-
-# t = 0.0
-# m = 1 / (c * c)
-# m1 = (
-# m * ((u - 2.0 * u_n + u_nm1) / Constant(dt**2)) * v * dx(scheme=qr_x)
-# )
-# a = dot(grad(u_n), grad(v)) * dx(scheme=qr_x) # explicit
-# f = Function(V)
-# nf = 0
-
-# if model["BCs"]["outer_bc"] == "non-reflective":
-# nf = c * ((u_n - u_nm1) / dt) * v * ds(scheme=qr_s)
-
-# h = CellSize(mesh)
-# FF = (
-# m1 + a + nf - (1 / (h / degree * h / degree)) * f * v * dx(scheme=qr_x)
-# )
-# X = Function(V)
-
-# lhs_ = lhs(FF)
-# rhs_ = rhs(FF)
-
-# problem = LinearVariationalProblem(lhs_, rhs_, X)
-# solver = LinearVariationalSolver(problem, solver_parameters=params)
-
-# usol_recv = []
-
-# P = FunctionSpace(receivers, "DG", 0)
-# interpolator = Interpolator(u_np1, P)
-# J0 = 0.0
-
-# for step in range(nt):
-# excitations.apply_source(f, wavelet[step])
-
-# solver.solve()
-# u_np1.assign(X)
-
-# rec = Function(P)
-# interpolator.interpolate(output=rec)
-
-# fwi = kwargs.get("fwi")
-# p_true_rec = kwargs.get("true_rec")
-
-# usol_recv.append(rec.dat.data)
-
-# if fwi:
-# J0 += calc_objective_func(rec, p_true_rec[step], step, dt, P)
-
-# if step % nspool == 0:
-# assert (
-# norm(u_n) < 1
-# ), "Numerical instability. Try reducing dt or building the mesh differently"
-# if output:
-# outfile.write(u_n, time=t, name="Pressure")
-# if t > 0:
-# helpers.display_progress(comm, t)
-
-# u_nm1.assign(u_n)
-# u_n.assign(u_np1)
-
-# t = step * float(dt)
-
-# if fwi:
-# return usol_recv, J0
-# else:
-# return usol_recv
-
-
-# def calc_objective_func(p_rec, p_true_rec, IT, dt, P):
-# true_rec = Function(P)
-# true_rec.dat.data[:] = p_true_rec
-# J = 0.5 * assemble(inner(true_rec - p_rec, true_rec - p_rec) * dx)
-# return J
-
-
-# def set_params(method):
-# if method == "KMV":
-# params = {"ksp_type": "preonly", "pc_type": "jacobi"}
-# elif (
-# method == "CG"
-# and mesh.ufl_cell() != quadrilateral
-# and mesh.ufl_cell() != hexahedron
-# ):
-# params = {"ksp_type": "cg", "pc_type": "jacobi"}
-# elif method == "CG" and (
-# mesh.ufl_cell() == quadrilateral or mesh.ufl_cell() == hexahedron
-# ):
-# params = {"ksp_type": "preonly", "pc_type": "jacobi"}
-# else:
-# raise ValueError("method is not yet supported")
-
-# return params
diff --git a/spyro/solvers/gradient_old.py b/spyro/solvers/gradient_old.py
deleted file mode 100644
index 3bcde5a2..00000000
--- a/spyro/solvers/gradient_old.py
+++ /dev/null
@@ -1,210 +0,0 @@
-# import firedrake as fire
-# from firedrake import dx, ds, Constant, grad, inner, dot
-# from firedrake.assemble import create_assembly_callable
-
-# from ..domains import quadrature, space
-# from ..pml import damping
-# from ..io import ensemble_gradient
-# from . import helpers
-
-# # Note this turns off non-fatal warnings
-# # set_log_level(ERROR)
-
-# __all__ = ["gradient"]
-
-
-# def gauss_lobatto_legendre_line_rule(degree):
-# fiat_make_rule = FIAT.quadrature.GaussLobattoLegendreQuadratureLineRule
-# fiat_rule = fiat_make_rule(FIAT.ufc_simplex(1), degree + 1)
-# finat_ps = finat.point_set.GaussLobattoLegendrePointSet
-# finat_qr = finat.quadrature.QuadratureRule
-# return finat_qr(finat_ps(fiat_rule.get_points()), fiat_rule.get_weights())
-
-
-# # 3D
-# def gauss_lobatto_legendre_cube_rule(dimension, degree):
-# make_tensor_rule = finat.quadrature.TensorProductQuadratureRule
-# result = gauss_lobatto_legendre_line_rule(degree)
-# for _ in range(1, dimension):
-# line_rule = gauss_lobatto_legendre_line_rule(degree)
-# result = make_tensor_rule([result, line_rule])
-# return result
-
-
-# @ensemble_gradient
-# def gradient(
-# model,
-# mesh,
-# comm,
-# c,
-# receivers,
-# guess,
-# residual,
-# output=False,
-# save_adjoint=False,
-# ):
-# """Discrete adjoint with secord-order in time fully-explicit timestepping scheme
-# with implementation of a Perfectly Matched Layer (PML) using
-# CG FEM with or without higher order mass lumping (KMV type elements).
-
-# Parameters
-# ----------
-# model: Python `dictionary`
-# Contains model options and parameters
-# mesh: Firedrake.mesh object
-# The 2D/3D triangular mesh
-# comm: Firedrake.ensemble_communicator
-# The MPI communicator for parallelism
-# c: Firedrake.Function
-# The velocity model interpolated onto the mesh nodes.
-# receivers: A :class:`spyro.Receivers` object.
-# Contains the receiver locations and sparse interpolation methods.
-# guess: A list of Firedrake functions
-# Contains the forward wavefield at a set of timesteps
-# residual: array-like [timesteps][receivers]
-# The difference between the observed and modeled data at
-# the receivers
-# output: boolean
-# optional, write the adjoint to disk (only for debugging)
-# save_adjoint: A list of Firedrake functions
-# Contains the adjoint at all timesteps
-
-# Returns
-# -------
-# dJdc_local: A Firedrake.Function containing the gradient of
-# the functional w.r.t. `c`
-# adjoint: Optional, a list of Firedrake functions containing the adjoint
-
-# """
-
-# method = model["opts"]["method"]
-# degree = model["opts"]["degree"]
-# dimension = model["opts"]["dimension"]
-# dt = model["timeaxis"]["dt"]
-# tf = model["timeaxis"]["tf"]
-# nspool = model["timeaxis"]["nspool"]
-# fspool = model["timeaxis"]["fspool"]
-
-# params = {"ksp_type": "cg", "pc_type": "jacobi"}
-
-# element = fire.FiniteElement(
-# method, mesh.ufl_cell(), degree=degree, variant="spectral"
-# )
-
-# V = fire.FunctionSpace(mesh, element)
-
-# qr_x = gauss_lobatto_legendre_cube_rule(dimension=dimension, degree=degree)
-# qr_s = gauss_lobatto_legendre_cube_rule(
-# dimension=(dimension - 1), degree=degree
-# )
-
-# nt = int(tf / dt) # number of timesteps
-
-# dJ = fire.Function(V, name="gradient")
-
-# # typical CG in N-d
-# u = fire.TrialFunction(V)
-# v = fire.TestFunction(V)
-
-# u_nm1 = fire.Function(V)
-# u_n = fire.Function(V)
-# u_np1 = fire.Function(V)
-
-# if output:
-# outfile = helpers.create_output_file("adjoint.pvd", comm, 0)
-
-# t = 0.0
-
-# # -------------------------------------------------------
-# m1 = ((u - 2.0 * u_n + u_nm1) / Constant(dt**2)) * v * dx(scheme=qr_x)
-# a = c * c * dot(grad(u_n), grad(v)) * dx(scheme=qr_x) # explicit
-
-# lhs1 = m1
-# rhs1 = -a
-
-# X = fire.Function(V)
-# B = fire.Function(V)
-
-# A = fire.assemble(lhs1, mat_type="matfree")
-# solver = fire.LinearSolver(A, solver_parameters=params)
-
-# # Define gradient problem
-# m_u = fire.TrialFunction(V)
-# m_v = fire.TestFunction(V)
-# mgrad = m_u * m_v * dx(rule=qr_x)
-
-# uuadj = fire.Function(V) # auxiliarly function for the gradient compt.
-# uufor = fire.Function(V) # auxiliarly function for the gradient compt.
-
-# ffG = 2.0 * c * dot(grad(uuadj), grad(uufor)) * m_v * dx(scheme=qr_x)
-
-# lhsG = mgrad
-# rhsG = ffG
-
-# gradi = fire.Function(V)
-# grad_prob = fire.LinearVariationalProblem(lhsG, rhsG, gradi)
-
-# if method == "KMV":
-# grad_solver = fire.LinearVariationalSolver(
-# grad_prob,
-# solver_parameters={
-# "ksp_type": "preonly",
-# "pc_type": "jacobi",
-# "mat_type": "matfree",
-# },
-# )
-# elif method == "CG":
-# grad_solver = fire.LinearVariationalSolver(
-# grad_prob,
-# solver_parameters={
-# "mat_type": "matfree",
-# },
-# )
-
-# assembly_callable = create_assembly_callable(rhs1, tensor=B)
-
-# rhs_forcing = fire.Function(V) # forcing term
-# if save_adjoint:
-# adjoint = [
-# fire.Function(V, name="adjoint_pressure") for t in range(nt)
-# ]
-# for step in range(nt - 1, -1, -1):
-# t = step * float(dt)
-# rhs_forcing.assign(0.0)
-# # Solver - main equation - (I)
-# # B = assemble(rhs_, tensor=B)
-# assembly_callable()
-
-# f = receivers.apply_receivers_as_source(rhs_forcing, residual, step)
-# # add forcing term to solve scalar pressure
-# B0 = B.sub(0)
-# B0 += f
-
-# # AX=B --> solve for X = B/Aˆ-1
-# solver.solve(X, B)
-
-# u_np1.assign(X)
-
-# # only compute for snaps that were saved
-# if step % fspool == 0:
-# # compute the gradient increment
-# uuadj.assign(u_np1)
-# uufor.assign(guess.pop())
-
-# grad_solver.solve()
-# dJ += gradi
-
-# u_nm1.assign(u_n)
-# u_n.assign(u_np1)
-
-# if step % nspool == 0:
-# if output:
-# outfile.write(u_n, time=t)
-# if save_adjoint:
-# adjoint.append(u_n)
-# helpers.display_progress(comm, t)
-
-# if save_adjoint:
-# return dJ, adjoint
-# else:
-# return dJ
diff --git a/spyro/solvers/time_integration.py b/spyro/solvers/time_integration.py
index 889fd834..1a846a39 100644
--- a/spyro/solvers/time_integration.py
+++ b/spyro/solvers/time_integration.py
@@ -1,5 +1,4 @@
from .time_integration_central_difference import central_difference
-from .time_integration_central_difference import mixed_space_central_difference
from .time_integration_central_difference import central_difference_MMS
@@ -11,13 +10,7 @@ def time_integrator(Wave_object, source_id=0):
def time_integrator_ricker(Wave_object, source_id=0):
- if Wave_object.time_integrator == "central_difference":
- return central_difference(Wave_object, source_id=source_id)
- elif Wave_object.time_integrator == "mixed_space_central_difference":
- return mixed_space_central_difference(Wave_object, source_id=source_id)
- else:
- raise ValueError("The time integrator specified is not implemented yet")
-
+ return central_difference(Wave_object, source_id=source_id)
def time_integrator_mms(Wave_object, source_id=0):
if Wave_object.time_integrator == "central_difference":
diff --git a/spyro/solvers/time_integration_central_difference.py b/spyro/solvers/time_integration_central_difference.py
index 603f92c5..0831a4ae 100644
--- a/spyro/solvers/time_integration_central_difference.py
+++ b/spyro/solvers/time_integration_central_difference.py
@@ -6,13 +6,13 @@
from .. import utils
-def central_difference(Wave_object, source_id=0):
+def central_difference(wave, source_id=0):
"""
Perform central difference time integration for wave propagation.
Parameters:
-----------
- Wave_object: Spyro object
+ wave: Spyro object
The Wave object containing the necessary data and parameters.
source_id: int (optional)
The ID of the source being propagated. Defaults to 0.
@@ -22,192 +22,72 @@ def central_difference(Wave_object, source_id=0):
tuple:
A tuple containing the forward solution and the receiver output.
"""
- excitations = Wave_object.sources
- excitations.current_source = source_id
- receivers = Wave_object.receivers
- comm = Wave_object.comm
- temp_filename = Wave_object.forward_output_file
+ wave.sources.current_source = source_id
- filename, file_extension = temp_filename.split(".")
+ filename, file_extension = wave.forward_output_file.split(".")
output_filename = filename + "sn" + str(source_id) + "." + file_extension
- if Wave_object.forward_output:
- parallel_print(f"Saving output in: {output_filename}", Wave_object.comm)
-
- output = fire.File(output_filename, comm=comm.comm)
- comm.comm.barrier()
-
- X = fire.Function(Wave_object.function_space)
-
- final_time = Wave_object.final_time
- dt = Wave_object.dt
- t = Wave_object.current_time
- nt = int((final_time - t) / dt) + 1 # number of timesteps
-
- u_nm1 = Wave_object.u_nm1
- u_n = Wave_object.u_n
- u_np1 = fire.Function(Wave_object.function_space)
-
- rhs_forcing = fire.Function(Wave_object.function_space)
+ if wave.forward_output:
+ parallel_print(f"Saving output in: {output_filename}", wave.comm)
+
+ output = fire.File(output_filename, comm=wave.comm.comm)
+ wave.comm.comm.barrier()
+
+ t = wave.current_time
+ nt = int(wave.final_time / wave.dt) + 1 # number of timesteps
+
+ rhs_forcing = fire.Function(wave.function_space)
usol = [
- fire.Function(Wave_object.function_space, name="pressure")
+ fire.Function(wave.function_space, name=wave.get_function_name())
for t in range(nt)
- if t % Wave_object.gradient_sampling_frequency == 0
+ if t % wave.gradient_sampling_frequency == 0
]
usol_recv = []
save_step = 0
- B = Wave_object.B
- rhs = Wave_object.rhs
-
for step in range(nt):
rhs_forcing.assign(0.0)
- B = fire.assemble(rhs, tensor=B)
- f = excitations.apply_source(rhs_forcing, Wave_object.wavelet[step])
- B0 = B.sub(0)
+ fire.assemble(wave.rhs, tensor=wave.B)
+ f = wave.sources.apply_source(rhs_forcing, wave.wavelet[step])
+ B0 = wave.B.sub(0)
B0 += f
- Wave_object.solver.solve(X, B)
+
+ wave.solver.solve(wave.next_vstate, wave.B)
- u_np1.assign(X)
+ wave.prev_vstate = wave.vstate
+ wave.vstate = wave.next_vstate
- usol_recv.append(
- Wave_object.receivers.interpolate(u_np1.dat.data_ro_with_halos[:])
- )
+ usol_recv.append(wave.get_receivers_output())
- if step % Wave_object.gradient_sampling_frequency == 0:
- usol[save_step].assign(u_np1)
+ if step % wave.gradient_sampling_frequency == 0:
+ usol[save_step].assign(wave.get_function())
save_step += 1
- if (step - 1) % Wave_object.output_frequency == 0:
+ if (step - 1) % wave.output_frequency == 0:
assert (
- fire.norm(u_n) < 1
+ fire.norm(wave.get_function()) < 1
), "Numerical instability. Try reducing dt or building the " \
"mesh differently"
- if Wave_object.forward_output:
- output.write(u_n, time=t, name="Pressure")
+ if wave.forward_output:
+ output.write(wave.get_function(), time=t,
+ name=wave.get_function_name())
- helpers.display_progress(Wave_object.comm, t)
+ helpers.display_progress(wave.comm, t)
- u_nm1.assign(u_n)
- u_n.assign(u_np1)
-
- t = step * float(dt)
-
- Wave_object.current_time = t
- helpers.display_progress(Wave_object.comm, t)
+ t = step * float(wave.dt)
+
+ wave.current_time = t
+ helpers.display_progress(wave.comm, t)
usol_recv = helpers.fill(
- usol_recv, receivers.is_local, nt, receivers.number_of_points
+ usol_recv, wave.receivers.is_local, nt, wave.receivers.number_of_points
)
- usol_recv = utils.utils.communicate(usol_recv, comm)
- Wave_object.receivers_output = usol_recv
+ usol_recv = utils.utils.communicate(usol_recv, wave.comm)
+ wave.receivers_output = usol_recv
- Wave_object.forward_solution = usol
- Wave_object.forward_solution_receivers = usol_recv
+ wave.forward_solution = usol
+ wave.forward_solution_receivers = usol_recv
return usol, usol_recv
-
-def mixed_space_central_difference(Wave_object, source_id=0):
- """
- Performs central difference time integration for wave propagation.
- Solves for a mixed space formulation, for function X. For correctly
- outputing pressure, order the mixed function space so that the space
- pressure lives in is first.
-
- Parameters:
- -----------
- Wave_object: Spyro object
- The Wave object containing the necessary data and parameters.
- source_id: int (optional)
- The ID of the source being propagated. Defaults to 0.
-
- Returns:
- --------
- tuple:
- A tuple containing the forward solution and the receiver output.
- """
- excitations = Wave_object.sources
- excitations.current_source = source_id
- receivers = Wave_object.receivers
- comm = Wave_object.comm
- temp_filename = Wave_object.forward_output_file
- filename, file_extension = temp_filename.split(".")
- output_filename = filename + "sn" + str(source_id) + "." + file_extension
- if Wave_object.forward_output:
- parallel_print(f"Saving output in: {output_filename}", Wave_object.comm)
-
- output = fire.File(output_filename, comm=comm.comm)
- comm.comm.barrier()
-
- final_time = Wave_object.final_time
- dt = Wave_object.dt
- t = Wave_object.current_time
- nt = int(final_time / dt) + 1 # number of timesteps
-
- X = Wave_object.X
- X_n = Wave_object.X_n
- X_nm1 = Wave_object.X_nm1
-
- rhs_forcing = fire.Function(Wave_object.function_space)
- usol = [
- fire.Function(Wave_object.function_space, name="pressure")
- for t in range(nt)
- if t % Wave_object.gradient_sampling_frequency == 0
- ]
- usol_recv = []
- save_step = 0
- B = Wave_object.B
- rhs_ = Wave_object.rhs
-
- for step in range(nt):
- rhs_forcing.assign(0.0)
- B = fire.assemble(rhs_, tensor=B)
- f = excitations.apply_source(rhs_forcing, Wave_object.wavelet[step])
- B0 = B.sub(0)
- B0 += f
- Wave_object.solver.solve(X, B)
-
- X_np1 = X
-
- X_nm1.assign(X_n)
- X_n.assign(X_np1)
-
- usol_recv.append(
- Wave_object.receivers.interpolate(
- X_np1.dat.data_ro_with_halos[0][:]
- )
- )
-
- if step % Wave_object.gradient_sampling_frequency == 0:
- usol[save_step].assign(X_np1.sub(0))
- save_step += 1
-
- if (step - 1) % Wave_object.output_frequency == 0:
- assert (
- fire.norm(X_np1.sub(0)) < 1
- ), "Numerical instability. Try reducing dt or building the " \
- "mesh differently"
- if Wave_object.forward_output:
- output.write(X_np1.sub(0), time=t, name="Pressure")
-
- helpers.display_progress(comm, t)
-
- t = step * float(dt)
-
- Wave_object.current_time = t
- helpers.display_progress(Wave_object.comm, t)
-
- usol_recv = helpers.fill(
- usol_recv, receivers.is_local, nt, receivers.number_of_points
- )
- usol_recv = utils.utils.communicate(usol_recv, comm)
- Wave_object.receivers_output = usol_recv
-
- Wave_object.forward_solution = usol
- Wave_object.forward_solution_receivers = usol_recv
-
- return usol, usol_recv
-
-
def central_difference_MMS(Wave_object, source_id=0):
"""Propagates the wave forward in time.
Currently uses central differences.
diff --git a/spyro/solvers/wave.py b/spyro/solvers/wave.py
index be2baf5e..b0f9b2db 100644
--- a/spyro/solvers/wave.py
+++ b/spyro/solvers/wave.py
@@ -290,3 +290,51 @@ def set_last_solve_as_real_shot_record(self):
if self.current_time == 0.0:
raise ValueError("No previous solve to set as real shot record.")
self.real_shot_record = self.forward_solution_receivers
+
+ @abstractmethod
+ def _set_vstate(self, vstate):
+ pass
+
+ @abstractmethod
+ def _get_vstate(self):
+ pass
+
+ @abstractmethod
+ def _set_prev_vstate(self, vstate):
+ pass
+
+ @abstractmethod
+ def _get_prev_vstate(self):
+ pass
+
+ @abstractmethod
+ def _set_next_vstate(self, vstate):
+ pass
+
+ @abstractmethod
+ def _get_next_vstate(self):
+ pass
+
+ # Managed attributes to access state variables in current, previous and next iteration
+ vstate = property(fget=lambda self: self._get_vstate(),
+ fset=lambda self, value: self._set_vstate(value))
+ prev_vstate = property(fget=lambda self: self._get_prev_vstate(),
+ fset=lambda self, value: self._set_prev_vstate(value))
+ next_vstate = property(fget=lambda self: self._get_next_vstate(),
+ fset=lambda self, value: self._set_next_vstate(value))
+
+ @abstractmethod
+ def get_receivers_output(self):
+ pass
+
+ @abstractmethod
+ def get_function(self):
+ '''Returns the function (e.g., pressure or displacement) associated with
+ the wave object without additional variables (e.g., PML variables)'''
+ pass
+
+ @abstractmethod
+ def get_function_name(self):
+ '''Returns the string representing the function of the wave object
+ (e.g., "pressure" or "displacement")'''
+ pass
\ No newline at end of file
diff --git a/spyro/utils/typing.py b/spyro/utils/typing.py
new file mode 100644
index 00000000..b393ccb6
--- /dev/null
+++ b/spyro/utils/typing.py
@@ -0,0 +1,6 @@
+def override(func):
+ '''
+ This decorator should be replaced by typing.override when Python
+ version is updated to 3.12
+ '''
+ return func
\ No newline at end of file
From baa852d0b32127de831236ac0971246478bdd035 Mon Sep 17 00:00:00 2001
From: Eduardo Moscatelli de Souza <5752216+SouzaEM@users.noreply.github.com>
Date: Tue, 2 Jul 2024 13:48:43 -0300
Subject: [PATCH 4/5] Refactor time integration code to merge
'central_difference' and 'central_difference_MMS' function into a unique
function
---
.../acoustic_solver_construction_no_pml.py | 7 +-
spyro/solvers/mms_acoustic.py | 19 ++-
spyro/solvers/time_integration.py | 15 --
.../time_integration_central_difference.py | 129 ++----------------
spyro/solvers/wave.py | 9 +-
test/test_MMS.py | 2 +-
6 files changed, 41 insertions(+), 140 deletions(-)
diff --git a/spyro/solvers/acoustic_solver_construction_no_pml.py b/spyro/solvers/acoustic_solver_construction_no_pml.py
index 252c0180..2bd93c12 100644
--- a/spyro/solvers/acoustic_solver_construction_no_pml.py
+++ b/spyro/solvers/acoustic_solver_construction_no_pml.py
@@ -35,9 +35,14 @@ def construct_solver_or_matrix_no_pml(Wave_object):
)
a = dot(grad(u_n), grad(v)) * dx(scheme=quad_rule) # explicit
+ le = 0
+ q = Wave_object.source_expression
+ if q is not None:
+ le = q * v * dx(scheme=quad_rule)
+
B = fire.Function(V)
- form = m1 + a
+ form = m1 + a - le
lhs = fire.lhs(form)
rhs = fire.rhs(form)
Wave_object.lhs = lhs
diff --git a/spyro/solvers/mms_acoustic.py b/spyro/solvers/mms_acoustic.py
index ccbfb28c..6cfebd3f 100644
--- a/spyro/solvers/mms_acoustic.py
+++ b/spyro/solvers/mms_acoustic.py
@@ -1,6 +1,6 @@
import firedrake as fire
from .acoustic_wave import AcousticWave
-
+from ..utils.typing import override
class AcousticWaveMMS(AcousticWave):
"""Class for solving the acoustic wave equation in 2D or 3D using
@@ -10,14 +10,21 @@ class AcousticWaveMMS(AcousticWave):
"""
def matrix_building(self):
+ self.mms_source_in_space()
+ self.q_t = fire.Constant(0)
+ self.source_expression = self.q_t * self.q_xy
+
super().matrix_building()
lhs = self.lhs
bcs = fire.DirichletBC(self.function_space, 0.0, "on_boundary")
A = fire.assemble(lhs, bcs=bcs, mat_type="matfree")
- self.mms_source_in_space()
self.solver = fire.LinearSolver(
A, solver_parameters=self.solver_parameters
)
+ dt = self.dt
+ t = self.current_time
+ self.u_nm1.assign(self.analytical_solution(t - 2 * dt))
+ self.u_n.assign(self.analytical_solution(t - dt))
def mms_source_in_space(self):
V = self.function_space
@@ -45,10 +52,6 @@ def mms_source_in_space(self):
# self.q_xy.interpolate(sin(pi*x)*sin(pi*y))
- def mms_source_in_time(self, t):
- # return fire.Constant(2*pi**2*t**2 + 2.0)
- return fire.Constant(2 * t)
-
def analytical_solution(self, t):
self.analytical = fire.Function(self.function_space)
x = self.mesh_z
@@ -66,3 +69,7 @@ def analytical_solution(self, t):
# self.analytical.assign(analytical)
return self.analytical
+
+ @override
+ def update_source_expression(self, t):
+ self.q_t.assign(2*t)
diff --git a/spyro/solvers/time_integration.py b/spyro/solvers/time_integration.py
index 1a846a39..45573f12 100644
--- a/spyro/solvers/time_integration.py
+++ b/spyro/solvers/time_integration.py
@@ -1,19 +1,4 @@
from .time_integration_central_difference import central_difference
-from .time_integration_central_difference import central_difference_MMS
-
def time_integrator(Wave_object, source_id=0):
- if Wave_object.source_type == "ricker":
- return time_integrator_ricker(Wave_object, source_id=source_id)
- elif Wave_object.source_type == "MMS":
- return time_integrator_mms(Wave_object, source_id=source_id)
-
-
-def time_integrator_ricker(Wave_object, source_id=0):
return central_difference(Wave_object, source_id=source_id)
-
-def time_integrator_mms(Wave_object, source_id=0):
- if Wave_object.time_integrator == "central_difference":
- return central_difference_MMS(Wave_object, source_id=source_id)
- else:
- raise ValueError("The time integrator specified is not implemented yet")
diff --git a/spyro/solvers/time_integration_central_difference.py b/spyro/solvers/time_integration_central_difference.py
index 0831a4ae..3bb4541b 100644
--- a/spyro/solvers/time_integration_central_difference.py
+++ b/spyro/solvers/time_integration_central_difference.py
@@ -1,5 +1,4 @@
import firedrake as fire
-from firedrake import Constant, dx, dot, grad
from ..io.basicio import parallel_print
from . import helpers
@@ -22,7 +21,9 @@ def central_difference(wave, source_id=0):
tuple:
A tuple containing the forward solution and the receiver output.
"""
- wave.sources.current_source = source_id
+ if wave.sources is not None:
+ wave.sources.current_source = source_id
+ rhs_forcing = fire.Function(wave.function_space)
filename, file_extension = wave.forward_output_file.split(".")
output_filename = filename + "sn" + str(source_id) + "." + file_extension
@@ -35,7 +36,6 @@ def central_difference(wave, source_id=0):
t = wave.current_time
nt = int(wave.final_time / wave.dt) + 1 # number of timesteps
- rhs_forcing = fire.Function(wave.function_space)
usol = [
fire.Function(wave.function_space, name=wave.get_function_name())
for t in range(nt)
@@ -44,11 +44,16 @@ def central_difference(wave, source_id=0):
usol_recv = []
save_step = 0
for step in range(nt):
- rhs_forcing.assign(0.0)
+ # Basic way of applying sources
+ wave.update_source_expression(t)
fire.assemble(wave.rhs, tensor=wave.B)
- f = wave.sources.apply_source(rhs_forcing, wave.wavelet[step])
- B0 = wave.B.sub(0)
- B0 += f
+
+ # More efficient way of applying sources
+ if wave.sources is not None:
+ rhs_forcing.assign(0.0)
+ f = wave.sources.apply_source(rhs_forcing, wave.wavelet[step])
+ B0 = wave.B.sub(0)
+ B0 += f
wave.solver.solve(wave.next_vstate, wave.B)
@@ -86,112 +91,4 @@ def central_difference(wave, source_id=0):
wave.forward_solution = usol
wave.forward_solution_receivers = usol_recv
- return usol, usol_recv
-
-def central_difference_MMS(Wave_object, source_id=0):
- """Propagates the wave forward in time.
- Currently uses central differences.
-
- Parameters:
- -----------
- dt: Python 'float' (optional)
- Time step to be used explicitly. If not mentioned uses the default,
- that was estabilished in the model_parameters.
- final_time: Python 'float' (optional)
- Time which simulation ends. If not mentioned uses the default,
- that was estabilished in the model_parameters.
- """
- receivers = Wave_object.receivers
- comm = Wave_object.comm
- temp_filename = Wave_object.forward_output_file
- filename, file_extension = temp_filename.split(".")
- output_filename = filename + "sn_mms_" + "." + file_extension
- if Wave_object.forward_output:
- print(f"Saving output in: {output_filename}", flush=True)
-
- output = fire.File(output_filename, comm=comm.comm)
- comm.comm.barrier()
-
- X = fire.Function(Wave_object.function_space)
-
- final_time = Wave_object.final_time
- dt = Wave_object.dt
- t = Wave_object.current_time
- nt = int((final_time - t) / dt) + 1 # number of timesteps
-
- u_nm1 = Wave_object.u_nm1
- u_n = Wave_object.u_n
- u_nm1.assign(Wave_object.analytical_solution(t - 2 * dt))
- u_n.assign(Wave_object.analytical_solution(t - dt))
- u_np1 = fire.Function(Wave_object.function_space, name="pressure t +dt")
- u = fire.TrialFunction(Wave_object.function_space)
- v = fire.TestFunction(Wave_object.function_space)
-
- usol = [
- fire.Function(Wave_object.function_space, name="pressure")
- for t in range(nt)
- if t % Wave_object.gradient_sampling_frequency == 0
- ]
- usol_recv = []
- save_step = 0
- B = Wave_object.B
- rhs = Wave_object.rhs
- quad_rule = Wave_object.quadrature_rule
-
- q_xy = Wave_object.q_xy
-
- for step in range(nt):
- q = q_xy * Wave_object.mms_source_in_time(t)
- m1 = (
- 1
- / (Wave_object.c * Wave_object.c)
- * ((u - 2.0 * u_n + u_nm1) / Constant(dt**2))
- * v
- * dx(scheme=quad_rule)
- )
- a = dot(grad(u_n), grad(v)) * dx(scheme=quad_rule)
- le = q * v * dx(scheme=quad_rule)
-
- form = m1 + a - le
- rhs = fire.rhs(form)
-
- B = fire.assemble(rhs, tensor=B)
-
- Wave_object.solver.solve(X, B)
-
- u_np1.assign(X)
-
- usol_recv.append(
- Wave_object.receivers.interpolate(u_np1.dat.data_ro_with_halos[:])
- )
-
- if step % Wave_object.gradient_sampling_frequency == 0:
- usol[save_step].assign(u_np1)
- save_step += 1
-
- if (step - 1) % Wave_object.output_frequency == 0:
- assert (
- fire.norm(u_n) < 1
- ), "Numerical instability. Try reducing dt or building the " \
- "mesh differently"
- if Wave_object.forward_output:
- output.write(u_n, time=t, name="Pressure")
- if t > 0:
- helpers.display_progress(Wave_object.comm, t)
-
- u_nm1.assign(u_n)
- u_n.assign(u_np1)
-
- t = step * float(dt)
-
- Wave_object.current_time = t
- helpers.display_progress(Wave_object.comm, t)
- Wave_object.analytical_solution(t)
-
- usol_recv = helpers.fill(
- usol_recv, receivers.is_local, nt, receivers.number_of_points
- )
- usol_recv = utils.utils.communicate(usol_recv, comm)
- Wave_object.receivers_output = usol_recv
-
- return usol, usol_recv
+ return usol, usol_recv
\ No newline at end of file
diff --git a/spyro/solvers/wave.py b/spyro/solvers/wave.py
index b0f9b2db..293b29d2 100644
--- a/spyro/solvers/wave.py
+++ b/spyro/solvers/wave.py
@@ -87,6 +87,8 @@ def __init__(self, dictionary=None, comm=None):
)
else:
warnings.warn("No mesh found. Please define a mesh.")
+ # Expression to define sources through UFL (less efficient)
+ self.source_expression = None
@abstractmethod
def forward_solve(self):
@@ -337,4 +339,9 @@ def get_function(self):
def get_function_name(self):
'''Returns the string representing the function of the wave object
(e.g., "pressure" or "displacement")'''
- pass
\ No newline at end of file
+ pass
+
+ def update_source_expression(self, t):
+ '''Update the source expression during wave propagation. This method must be
+ implemented only by subclasses that make use of the source term'''
+ pass
diff --git a/test/test_MMS.py b/test/test_MMS.py
index ab2cb44c..e44b90d1 100644
--- a/test/test_MMS.py
+++ b/test/test_MMS.py
@@ -35,7 +35,7 @@ def run_solve(model):
Wave_obj.set_initial_velocity_model(expression="1 + sin(pi*-z)*sin(pi*x)")
Wave_obj.forward_solve()
- u_an = Wave_obj.analytical
+ u_an = Wave_obj.analytical_solution(Wave_obj.current_time)
u_num = Wave_obj.u_n
return errornorm(u_num, u_an)
From db5b048a3285b32d08d7720f2833e97594f87a90 Mon Sep 17 00:00:00 2001
From: Eduardo Moscatelli de Souza <5752216+SouzaEM@users.noreply.github.com>
Date: Tue, 2 Jul 2024 14:07:44 -0300
Subject: [PATCH 5/5] Remove correction of time integration configuration for
ABC
---
spyro/io/model_parameters.py | 6 ------
1 file changed, 6 deletions(-)
diff --git a/spyro/io/model_parameters.py b/spyro/io/model_parameters.py
index 89e86037..f24ac84c 100644
--- a/spyro/io/model_parameters.py
+++ b/spyro/io/model_parameters.py
@@ -383,12 +383,6 @@ def _sanitize_absorving_boundary_condition(self):
self.abc_R = BL_obj.abc_R
self.abc_pad_length = BL_obj.abc_pad_length
self.abc_boundary_layer_type = BL_obj.abc_boundary_layer_type
- if self.abc_status:
- self._correct_time_integrator_for_abc()
-
- def _correct_time_integrator_for_abc(self):
- if self.time_integrator == "central_difference":
- self.time_integrator = "mixed_space_central_difference"
def _sanitize_output(self):
# default_dictionary["visualization"] = {