From the plot above, we can notice two things. First, it is rotated by 90 degrees. In cases like this, the saved PNG image will be in the correct orientation. However, the plotted image in this notebook is rotated because the velocity model we used, commonly with segy files, has the Z-axis before the X-axis in data input.
-
Another observation is that, unlike in the simple forward tutorial, we looked at the experiment layout after running the forward solve. For memory purposes, a 2D or 3D velocity file is only actually interpolated into our domain when it is necessary for another method of the Wave object class. If you need to force the interpolation sooner, call the _get_initial_velocity_model() method.
+
Another observation is that, unlike in the simple forward tutorial, we looked at the experiment layout after running the forward solve. For memory purposes, a 2D or 3D velocity file is only actually interpolated into our domain when it is necessary for another method of the Wave object class. If you need to force the interpolation sooner, call the _initialize_model_parameters() method.
It is also important to note that even though receivers look like a line, they are actually located in points, which can be visible by zooming into the image, not coinciding with nodes.
diff --git a/notebook_tutorials/simple_forward_with_overthrust.ipynb b/notebook_tutorials/simple_forward_with_overthrust.ipynb
index f51e3c50..69d76d55 100644
--- a/notebook_tutorials/simple_forward_with_overthrust.ipynb
+++ b/notebook_tutorials/simple_forward_with_overthrust.ipynb
@@ -397,7 +397,7 @@
"source": [
"From the plot above, we can notice two things. First, it is rotated by 90 degrees. In cases like this, the saved PNG image will be in the correct orientation. However, the plotted image in this notebook is rotated because the velocity model we used, commonly with segy files, has the Z-axis before the X-axis in data input.\n",
"\n",
- "Another observation is that, unlike in the **simple forward tutorial**, we looked at the experiment layout after running the forward solve. For memory purposes, a 2D or 3D velocity file is only actually interpolated into our domain when it is necessary for another method of the Wave object class. If you need to force the interpolation sooner, call the _get_initial_velocity_model() method.\n",
+ "Another observation is that, unlike in the **simple forward tutorial**, we looked at the experiment layout after running the forward solve. For memory purposes, a 2D or 3D velocity file is only actually interpolated into our domain when it is necessary for another method of the Wave object class. If you need to force the interpolation sooner, call the _initialize_model_parameters() method.\n",
"\n",
"It is also important to note that even though receivers look like a line, they are actually located in points, which can be visible by zooming into the image, not coinciding with nodes."
]
diff --git a/spyro/examples/camembert.py b/spyro/examples/camembert.py
index 2abb71b0..0fc6b69b 100644
--- a/spyro/examples/camembert.py
+++ b/spyro/examples/camembert.py
@@ -156,3 +156,7 @@ def _camembert_velocity_model(self):
)
self.set_initial_velocity_model(conditional=cond, dg_velocity_model=False)
return None
+
+if __name__ == "__main__":
+ wave = Camembert_acoustic()
+ wave.forward_solve()
diff --git a/spyro/examples/rectangle.py b/spyro/examples/rectangle.py
index 8f2f6693..5257adf9 100644
--- a/spyro/examples/rectangle.py
+++ b/spyro/examples/rectangle.py
@@ -182,25 +182,6 @@ def multiple_layer_velocity_model(self, z_switch, layers):
# cond = fire.conditional(self.mesh_z > z_switch, layer1, layer2)
self.set_initial_velocity_model(conditional=cond)
-
-# class Rectangle(AcousticWave):
-# def __init__(self, model_dictionary=None, comm=None):
-# model_parameters = Rectangle_parameters(
-# dictionary=model_dictionary, comm=comm
-# )
-# super().__init__(
-# model_parameters=model_parameters, comm=model_parameters.comm
-# )
-# comm = self.comm
-# num_sources = self.number_of_sources
-# if comm.comm.rank == 0 and comm.ensemble_comm.rank == 0:
-# print(
-# "INFO: Distributing %d shot(s) across %d core(s). \
-# Each shot is using %d cores"
-# % (
-# num_sources,
-# fire.COMM_WORLD.size,
-# fire.COMM_WORLD.size / comm.ensemble_comm.size,
-# ),
-# flush=True,
-# )
+if __name__ == "__main__":
+ wave = Rectangle_acoustic()
+ wave.forward_solve()
diff --git a/spyro/io/model_parameters.py b/spyro/io/model_parameters.py
index 1904dc55..367ef3b8 100644
--- a/spyro/io/model_parameters.py
+++ b/spyro/io/model_parameters.py
@@ -383,12 +383,6 @@ def _sanitize_absorving_boundary_condition(self):
self.abc_R = BL_obj.abc_R
self.abc_pad_length = BL_obj.abc_pad_length
self.abc_boundary_layer_type = BL_obj.abc_boundary_layer_type
- if self.abc_active:
- self._correct_time_integrator_for_abc()
-
- def _correct_time_integrator_for_abc(self):
- if self.time_integrator == "central_difference":
- self.time_integrator = "mixed_space_central_difference"
def _sanitize_output(self):
# default_dictionary["visualization"] = {
@@ -608,9 +602,9 @@ def _sanitize_optimization_and_velocity(self):
if "velocity_conditional" not in dictionary["synthetic_data"]:
self.velocity_model_type = None
warnings.warn(
- "No velocity model set initially. If using \
- user defined conditional or expression, please \
- input it in the Wave object."
+ "No velocity model set initially. If using " \
+ "user defined conditional or expression, please " \
+ "input it in the Wave object."
)
if "velocity_conditional" in dictionary["synthetic_data"]:
diff --git a/spyro/solvers/acoustic_solver_construction_no_pml.py b/spyro/solvers/acoustic_solver_construction_no_pml.py
index 252c0180..2bd93c12 100644
--- a/spyro/solvers/acoustic_solver_construction_no_pml.py
+++ b/spyro/solvers/acoustic_solver_construction_no_pml.py
@@ -35,9 +35,14 @@ def construct_solver_or_matrix_no_pml(Wave_object):
)
a = dot(grad(u_n), grad(v)) * dx(scheme=quad_rule) # explicit
+ le = 0
+ q = Wave_object.source_expression
+ if q is not None:
+ le = q * v * dx(scheme=quad_rule)
+
B = fire.Function(V)
- form = m1 + a
+ form = m1 + a - le
lhs = fire.lhs(form)
rhs = fire.rhs(form)
Wave_object.lhs = lhs
diff --git a/spyro/solvers/acoustic_wave.py b/spyro/solvers/acoustic_wave.py
index 3ec51d44..990c7ffc 100644
--- a/spyro/solvers/acoustic_wave.py
+++ b/spyro/solvers/acoustic_wave.py
@@ -14,7 +14,7 @@
from .backward_time_integration import (
backward_wave_propagator,
)
-
+from ..utils.typing import override
class AcousticWave(Wave):
def save_current_velocity_model(self, file_name=None):
@@ -40,7 +40,7 @@ def forward_solve(self):
if self.function_space is None:
self.force_rebuild_function_space()
- self._get_initial_velocity_model()
+ self._initialize_model_parameters()
self.c = self.initial_velocity_model
self.matrix_building()
self.wave_propagator()
@@ -68,6 +68,7 @@ def matrix_building(self):
self.trial_function = None
self.u_nm1 = None
self.u_n = None
+ self.u_np1 = fire.Function(self.function_space)
self.lhs = None
self.solver = None
self.rhs = None
@@ -81,6 +82,7 @@ def matrix_building(self):
self.X = None
self.X_n = None
self.X_nm1 = None
+ self.X_np1 = fire.Function(V * Z)
construct_solver_or_matrix_with_pml(self)
@ensemble_propagator
@@ -144,10 +146,94 @@ def reset_pressure(self):
self.u_n.assign(0.0)
except:
warnings.warn("No pressure to reset")
- if self.abc_active:
- try:
- self.X_n.assign(0.0)
- self.X_nm1.assign(0.0)
- except:
- warnings.warn("No mixed space pressure to reset")
-
+
+ @override
+ def _initialize_model_parameters(self):
+ if self.initial_velocity_model is not None:
+ return None
+
+ if self.initial_velocity_model_file is None:
+ raise ValueError("No velocity model or velocity file to load.")
+
+ if self.initial_velocity_model_file.endswith(".segy"):
+ vp_filename, vp_filetype = os.path.splitext(
+ self.initial_velocity_model_file
+ )
+ warnings.warn("Converting segy file to hdf5")
+ write_velocity_model(
+ self.initial_velocity_model_file, ofname=vp_filename
+ )
+ self.initial_velocity_model_file = vp_filename + ".hdf5"
+
+ if self.initial_velocity_model_file.endswith((".hdf5", ".h5")):
+ self.initial_velocity_model = interpolate(
+ self,
+ self.initial_velocity_model_file,
+ self.function_space.sub(0),
+ )
+
+ if self.debug_output:
+ fire.File("initial_velocity_model.pvd").write(
+ self.initial_velocity_model, name="velocity"
+ )
+
+ @override
+ def _set_vstate(self, vstate):
+ if self.abc_boundary_layer_type == "PML":
+ self.X_n.assign(vstate)
+ else:
+ self.u_n.assign(vstate)
+
+ @override
+ def _get_vstate(self):
+ if self.abc_boundary_layer_type == "PML":
+ return self.X_n
+ else:
+ return self.u_n
+
+ @override
+ def _set_prev_vstate(self, vstate):
+ if self.abc_boundary_layer_type == "PML":
+ self.X_nm1.assign(vstate)
+ else:
+ self.u_nm1.assign(vstate)
+
+ @override
+ def _get_prev_vstate(self):
+ if self.abc_boundary_layer_type == "PML":
+ return self.X_nm1
+ else:
+ return self.u_nm1
+
+ @override
+ def _set_next_vstate(self, vstate):
+ if self.abc_boundary_layer_type == "PML":
+ self.X_np1.assign(vstate)
+ else:
+ self.u_np1.assign(vstate)
+
+ @override
+ def _get_next_vstate(self):
+ if self.abc_boundary_layer_type == "PML":
+ return self.X_np1
+ else:
+ return self.u_np1
+
+ @override
+ def get_receivers_output(self):
+ if self.abc_boundary_layer_type == "PML":
+ data_with_halos = self.X_n.dat.data_ro_with_halos[0][:]
+ else:
+ data_with_halos = self.u_n.dat.data_ro_with_halos[:]
+ return self.receivers.interpolate(data_with_halos)
+
+ @override
+ def get_function(self):
+ if self.abc_boundary_layer_type == "PML":
+ return self.X_n.sub(0)
+ else:
+ return self.u_n
+
+ @override
+ def get_function_name(self):
+ return "Pressure"
diff --git a/spyro/solvers/dg_wave.py b/spyro/solvers/dg_wave.py
deleted file mode 100644
index 0abe5e89..00000000
--- a/spyro/solvers/dg_wave.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# import firedrake as fire
-# from firedrake import dot, grad, jump, avg, dx, ds, dS, Constant
-# from spyro import Wave
-
-# fire.set_log_level(fire.ERROR)
-
-
-# class DG_Wave(Wave):
-# def matrix_building(self):
-# """Builds solver operators. Doesn't create mass matrices if
-# matrix_free option is on,
-# which it is by default.
-# """
-# V = self.function_space
-# # Trial and test functions
-# u = fire.TrialFunction(V)
-# v = fire.TestFunction(V)
-
-# # # Previous functions for time integration
-# u_n = fire.Function(V)
-# u_nm1 = fire.Function(V)
-# self.u_nm1 = u_nm1
-# self.u_n = u_n
-# c = self.c
-
-# self.current_time = 0.0
-# dt = self.dt
-
-# # Normal component, cell size and right-hand side
-# h = fire.CellDiameter(self.mesh)
-# h_avg = (h("+") + h("-")) / 2
-# n = fire.FacetNormal(self.mesh)
-
-# # Parameters
-# alpha = 4.0
-# gamma = 8.0
-
-# # Bilinear form
-# a = (
-# dot(grad(v), grad(u)) * dx
-# - dot(avg(grad(v)), jump(u, n)) * dS
-# - dot(jump(v, n), avg(grad(u))) * dS
-# + alpha / h_avg * dot(jump(v, n), jump(u, n)) * dS
-# - dot(grad(v), u * n) * ds
-# - dot(v * n, grad(u)) * ds
-# + (gamma / h) * v * u * ds
-# + ((u) / Constant(dt**2)) / c * v * dx
-# )
-# # Linear form
-# b = ((2.0 * u_n - u_nm1) / Constant(dt**2)) / c * v * dx
-# form = a - b
-
-# lhs = fire.lhs(form)
-# rhs = fire.rhs(form)
-
-# A = fire.assemble(lhs)
-# params = {"ksp_type": "gmres"}
-# self.solver = fire.LinearSolver(A, solver_parameters=params)
-
-# # lterar para como o thiago fez
-# self.rhs = rhs
-# B = fire.Function(V)
-# self.B = B
diff --git a/spyro/solvers/elastic_wave/elastic_solver_construction_no_pml.py b/spyro/solvers/elastic_wave/elastic_solver_construction_no_pml.py
new file mode 100644
index 00000000..e69de29b
diff --git a/spyro/solvers/elastic_wave/elastic_solver_construction_with_pml.py b/spyro/solvers/elastic_wave/elastic_solver_construction_with_pml.py
new file mode 100644
index 00000000..e69de29b
diff --git a/spyro/solvers/elastic_wave/elastic_wave.py b/spyro/solvers/elastic_wave/elastic_wave.py
new file mode 100644
index 00000000..d18d2c0e
--- /dev/null
+++ b/spyro/solvers/elastic_wave/elastic_wave.py
@@ -0,0 +1,29 @@
+from abc import abstractmethod
+
+from ..wave import Wave
+
+class ElasticWave(Wave):
+ '''Base class for elastic wave propagators'''
+ def __init__(self, dictionary, comm=None):
+ super().__init__(dictionary, comm=comm)
+
+ #@override
+ def _initialize_model_parameters(self):
+ d = self.input_dictionary.get("synthetic_data", False)
+ if bool(d) and "type" in d:
+ if d["type"] == "object":
+ self.initialize_model_parameters_from_object(d)
+ elif d["type"] == "file":
+ self.initialize_model_parameters_from_file(d)
+ else:
+ raise Exception(f"Invalid synthetic data type: {d['type']}")
+ else:
+ raise Exception("Input dictionary must contain ['synthetic_data']['type']")
+
+ @abstractmethod
+ def initialize_model_parameters_from_object(self, synthetic_data_dict):
+ pass
+
+ @abstractmethod
+ def initialize_model_parameters_from_file(self, synthetic_data_dict):
+ pass
diff --git a/spyro/solvers/elastic_wave/isotropic_wave.py b/spyro/solvers/elastic_wave/isotropic_wave.py
new file mode 100644
index 00000000..52caf4e7
--- /dev/null
+++ b/spyro/solvers/elastic_wave/isotropic_wave.py
@@ -0,0 +1,48 @@
+from .elastic_wave import ElasticWave
+
+class IsotropicWave(ElasticWave):
+ '''Isotropic elastic wave propagator'''
+ def __init__(self, dictionary, comm=None):
+ super().__init__(dictionary, comm=comm)
+
+ self.rho = None # Density
+ self.lmbda = None # First Lame parameter
+ self.mu = None # Second Lame parameter
+ self.c_s = None # Secondary wave velocity
+
+ #@override
+ def initialize_model_parameters_from_object(self, synthetic_data_dict: dict):
+ self.rho = synthetic_data_dict.get("density", None)
+ self.lmbda = synthetic_data_dict.get("lambda",
+ synthetic_data_dict.get("lame_first", None))
+ self.mu = synthetic_data_dict.get("mu",
+ synthetic_data_dict.get("lame_second", None))
+ self.c = synthetic_data_dict.get("p_wave_velocity", None)
+ self.c_s = synthetic_data_dict.get("s_wave_velocity", None)
+
+ # Check if {rho, lambda, mu} is set and {c, c_s} are not
+ option_1 = bool(self.rho) and \
+ bool(self.lmbda) and \
+ bool(self.mu) and \
+ not bool(self.c) and \
+ not bool(self.c_s)
+ # Check if {rho, c, c_s} is set and {lambda, mu} are not
+ option_2 = bool(self.rho) and \
+ bool(self.c) and \
+ bool(self.c_s) and \
+ not bool(self.lmbda) and \
+ not bool(self.mu)
+
+ if not option_1 and not option_2:
+ raise Exception(f"Inconsistent selection of isotropic elastic wave parameters:\n" \
+ f" Density : {bool(self.rho)}\n"\
+ f" Lame first : {bool(self.lmbda)}\n"\
+ f" Lame second : {bool(self.mu)}\n"\
+ f" P-wave velocity: {bool(self.c)}\n"\
+ f" S-wave velocity: {bool(self.c_s)}\n"\
+ "The valid options are \{Density, Lame first, Lame second\} "\
+ "or \{Density, P-wave velocity, S-wave velocity\}")
+
+ #@override
+ def initialize_model_parameters_from_file(self, synthetic_data_dict):
+ raise NotImplementedError
diff --git a/spyro/solvers/forward.py b/spyro/solvers/forward.py
deleted file mode 100644
index 06d38e85..00000000
--- a/spyro/solvers/forward.py
+++ /dev/null
@@ -1,181 +0,0 @@
-# import firedrake as fire
-# from firedrake.assemble import create_assembly_callable
-# from firedrake import Constant, dx, dot, inner, grad, ds
-# import FIAT
-# import finat
-
-# from .. import utils
-# from ..domains import quadrature, space
-# from ..pml import damping
-# from ..io import ensemble_forward
-# from . import helpers
-
-
-# def gauss_lobatto_legendre_line_rule(degree):
-# fiat_make_rule = FIAT.quadrature.GaussLobattoLegendreQuadratureLineRule
-# fiat_rule = fiat_make_rule(FIAT.ufc_simplex(1), degree + 1)
-# finat_ps = finat.point_set.GaussLobattoLegendrePointSet
-# finat_qr = finat.quadrature.QuadratureRule
-# return finat_qr(finat_ps(fiat_rule.get_points()), fiat_rule.get_weights())
-
-
-# # 3D
-# def gauss_lobatto_legendre_cube_rule(dimension, degree):
-# """Returns GLL integration rule
-
-# Parameters
-# ----------
-# dimension: `int`
-# The dimension of the mesh
-# degree: `int`
-# The degree of the function space
-
-# Returns
-# -------
-# result: `finat.quadrature.QuadratureRule`
-# The GLL integration rule
-# """
-# make_tensor_rule = finat.quadrature.TensorProductQuadratureRule
-# result = gauss_lobatto_legendre_line_rule(degree)
-# for _ in range(1, dimension):
-# line_rule = gauss_lobatto_legendre_line_rule(degree)
-# result = make_tensor_rule([result, line_rule])
-# return result
-
-
-# @ensemble_forward
-# def forward(
-# model,
-# mesh,
-# comm,
-# c,
-# excitations,
-# wavelet,
-# receivers,
-# source_num=0,
-# output=False,
-# ):
-# """Secord-order in time fully-explicit scheme
-# with implementation of a Perfectly Matched Layer (PML) using
-# CG FEM with or without higher order mass lumping (KMV type elements).
-
-# Parameters
-# ----------
-# model: Python `dictionary`
-# Contains model options and parameters
-# mesh: Firedrake.mesh object
-# The 2D/3D triangular mesh
-# comm: Firedrake.ensemble_communicator
-# The MPI communicator for parallelism
-# c: Firedrake.Function
-# The velocity model interpolated onto the mesh.
-# excitations: A list Firedrake.Functions
-# wavelet: array-like
-# Time series data that's injected at the source location.
-# receivers: A :class:`spyro.Receivers` object.
-# Contains the receiver locations and sparse interpolation methods.
-# source_num: `int`, optional
-# The source number you wish to simulate
-# output: `boolean`, optional
-# Whether or not to write results to pvd files.
-
-# Returns
-# -------
-# usol: list of Firedrake.Functions
-# The full field solution at `fspool` timesteps
-# usol_recv: array-like
-# The solution interpolated to the receivers at all timesteps
-
-# """
-
-# method = model["opts"]["method"]
-# degree = model["opts"]["degree"]
-# dt = model["timeaxis"]["dt"]
-# final_time = model["timeaxis"]["tf"]
-# nspool = model["timeaxis"]["nspool"]
-# fspool = model["timeaxis"]["fspool"]
-# excitations.current_source = source_num
-
-# nt = int(final_time / dt) # number of timesteps
-
-# element = fire.FiniteElement(method, mesh.ufl_cell(), degree=degree)
-
-# V = fire.FunctionSpace(mesh, element)
-
-# # typical CG FEM in 2d/3d
-# u = fire.TrialFunction(V)
-# v = fire.TestFunction(V)
-
-# u_nm1 = fire.Function(V)
-# u_n = fire.Function(V, name="pressure")
-# u_np1 = fire.Function(V)
-
-# if output:
-# outfile = helpers.create_output_file("forward.pvd", comm, source_num)
-
-# t = 0.0
-
-# # -------------------------------------------------------
-# m1 = ((u) / Constant(dt**2)) * v * dx
-# a = (
-# c * c * dot(grad(u_n), grad(v)) * dx
-# + ((-2.0 * u_n + u_nm1) / Constant(dt**2)) * v * dx
-# ) # explicit
-
-# X = fire.Function(V)
-# B = fire.Function(V)
-
-# lhs = m1
-# rhs = -a
-
-# A = fire.assemble(lhs)
-# solver = fire.LinearSolver(A)
-
-# usol = [
-# fire.Function(V, name="pressure") for t in range(nt) if t % fspool == 0
-# ]
-# usol_recv = []
-# save_step = 0
-
-# assembly_callable = create_assembly_callable(rhs, tensor=B)
-
-# rhs_forcing = fire.Function(V)
-
-# for step in range(nt):
-# rhs_forcing.assign(0.0)
-# assembly_callable()
-# f = excitations.apply_source(rhs_forcing, wavelet[step])
-# B0 = B.sub(0)
-# B0 += f
-# solver.solve(X, B)
-
-# u_np1.assign(X)
-
-# usol_recv.append(
-# receivers.interpolate(u_np1.dat.data_ro_with_halos[:])
-# )
-
-# if step % fspool == 0:
-# usol[save_step].assign(u_np1)
-# save_step += 1
-
-# if step % nspool == 0:
-# assert (
-# fire.norm(u_n) < 1
-# ), "Numerical instability. Try reducing dt or building the mesh differently"
-# if output:
-# outfile.write(u_n, time=t, name="Pressure")
-# if t > 0:
-# helpers.display_progress(comm, t)
-
-# u_nm1.assign(u_n)
-# u_n.assign(u_np1)
-
-# t = step * float(dt)
-
-# usol_recv = helpers.fill(
-# usol_recv, receivers.is_local, nt, receivers.num_receivers
-# )
-# usol_recv = utils.communicate(usol_recv, comm)
-
-# return usol, usol_recv
diff --git a/spyro/solvers/forward_AD.py b/spyro/solvers/forward_AD.py
deleted file mode 100644
index a8e257f3..00000000
--- a/spyro/solvers/forward_AD.py
+++ /dev/null
@@ -1,180 +0,0 @@
-# from firedrake import *
-
-# # from .. import utils
-# from ..domains import quadrature, space
-
-# # from ..pml import damping
-# # from ..io import ensemble_forward
-# from . import helpers
-
-# # Note this turns off non-fatal warnings
-# set_log_level(ERROR)
-
-
-# # @ensemble_forward
-# def forward(
-# model,
-# mesh,
-# comm,
-# c,
-# excitations,
-# wavelet,
-# receivers,
-# source_num=0,
-# output=False,
-# **kwargs
-# ):
-# """Secord-order in time fully-explicit scheme
-# with implementation of a Perfectly Matched Layer (PML) using
-# CG FEM with or without higher order mass lumping (KMV type elements).
-
-# Parameters
-# ----------
-# model: Python `dictionary`
-# Contains model options and parameters
-# mesh: Firedrake.mesh object
-# The 2D/3D triangular mesh
-# comm: Firedrake.ensemble_communicator
-# The MPI communicator for parallelism
-# c: Firedrake.Function
-# The velocity model interpolated onto the mesh.
-# excitations: A list Firedrake.Functions
-# wavelet: array-like
-# Time series data that's injected at the source location.
-# receivers: A :class:`spyro.Receivers` object.
-# Contains the receiver locations and sparse interpolation methods.
-# source_num: `int`, optional
-# The source number you wish to simulate
-# output: `boolean`, optional
-# Whether or not to write results to pvd files.
-
-# Returns
-# -------
-# usol: list of Firedrake.Functions
-# The full field solution at `fspool` timesteps
-# usol_recv: array-like
-# The solution interpolated to the receivers at all timesteps
-
-# """
-
-# method = model["opts"]["method"]
-# degree = model["opts"]["degree"]
-# dim = model["opts"]["dimension"]
-# dt = model["timeaxis"]["dt"]
-# tf = model["timeaxis"]["tf"]
-# nspool = model["timeaxis"]["nspool"]
-# nt = int(tf / dt) # number of timesteps
-# excitations.current_source = source_num
-# params = set_params(method)
-# element = space.FE_method(mesh, method, degree)
-
-# V = FunctionSpace(mesh, element)
-
-# qr_x, qr_s, _ = quadrature.quadrature_rules(V)
-
-# if dim == 2:
-# z, x = SpatialCoordinate(mesh)
-# elif dim == 3:
-# z, x, y = SpatialCoordinate(mesh)
-
-# u = TrialFunction(V)
-# v = TestFunction(V)
-
-# u_nm1 = Function(V)
-# u_n = Function(V)
-# u_np1 = Function(V)
-
-# if output:
-# outfile = helpers.create_output_file("forward.pvd", comm, source_num)
-
-# t = 0.0
-# m = 1 / (c * c)
-# m1 = (
-# m * ((u - 2.0 * u_n + u_nm1) / Constant(dt**2)) * v * dx(scheme=qr_x)
-# )
-# a = dot(grad(u_n), grad(v)) * dx(scheme=qr_x) # explicit
-# f = Function(V)
-# nf = 0
-
-# if model["BCs"]["outer_bc"] == "non-reflective":
-# nf = c * ((u_n - u_nm1) / dt) * v * ds(scheme=qr_s)
-
-# h = CellSize(mesh)
-# FF = (
-# m1 + a + nf - (1 / (h / degree * h / degree)) * f * v * dx(scheme=qr_x)
-# )
-# X = Function(V)
-
-# lhs_ = lhs(FF)
-# rhs_ = rhs(FF)
-
-# problem = LinearVariationalProblem(lhs_, rhs_, X)
-# solver = LinearVariationalSolver(problem, solver_parameters=params)
-
-# usol_recv = []
-
-# P = FunctionSpace(receivers, "DG", 0)
-# interpolator = Interpolator(u_np1, P)
-# J0 = 0.0
-
-# for step in range(nt):
-# excitations.apply_source(f, wavelet[step])
-
-# solver.solve()
-# u_np1.assign(X)
-
-# rec = Function(P)
-# interpolator.interpolate(output=rec)
-
-# fwi = kwargs.get("fwi")
-# p_true_rec = kwargs.get("true_rec")
-
-# usol_recv.append(rec.dat.data)
-
-# if fwi:
-# J0 += calc_objective_func(rec, p_true_rec[step], step, dt, P)
-
-# if step % nspool == 0:
-# assert (
-# norm(u_n) < 1
-# ), "Numerical instability. Try reducing dt or building the mesh differently"
-# if output:
-# outfile.write(u_n, time=t, name="Pressure")
-# if t > 0:
-# helpers.display_progress(comm, t)
-
-# u_nm1.assign(u_n)
-# u_n.assign(u_np1)
-
-# t = step * float(dt)
-
-# if fwi:
-# return usol_recv, J0
-# else:
-# return usol_recv
-
-
-# def calc_objective_func(p_rec, p_true_rec, IT, dt, P):
-# true_rec = Function(P)
-# true_rec.dat.data[:] = p_true_rec
-# J = 0.5 * assemble(inner(true_rec - p_rec, true_rec - p_rec) * dx)
-# return J
-
-
-# def set_params(method):
-# if method == "KMV":
-# params = {"ksp_type": "preonly", "pc_type": "jacobi"}
-# elif (
-# method == "CG"
-# and mesh.ufl_cell() != quadrilateral
-# and mesh.ufl_cell() != hexahedron
-# ):
-# params = {"ksp_type": "cg", "pc_type": "jacobi"}
-# elif method == "CG" and (
-# mesh.ufl_cell() == quadrilateral or mesh.ufl_cell() == hexahedron
-# ):
-# params = {"ksp_type": "preonly", "pc_type": "jacobi"}
-# else:
-# raise ValueError("method is not yet supported")
-
-# return params
diff --git a/spyro/solvers/gradient_old.py b/spyro/solvers/gradient_old.py
deleted file mode 100644
index 3bcde5a2..00000000
--- a/spyro/solvers/gradient_old.py
+++ /dev/null
@@ -1,210 +0,0 @@
-# import firedrake as fire
-# from firedrake import dx, ds, Constant, grad, inner, dot
-# from firedrake.assemble import create_assembly_callable
-
-# from ..domains import quadrature, space
-# from ..pml import damping
-# from ..io import ensemble_gradient
-# from . import helpers
-
-# # Note this turns off non-fatal warnings
-# # set_log_level(ERROR)
-
-# __all__ = ["gradient"]
-
-
-# def gauss_lobatto_legendre_line_rule(degree):
-# fiat_make_rule = FIAT.quadrature.GaussLobattoLegendreQuadratureLineRule
-# fiat_rule = fiat_make_rule(FIAT.ufc_simplex(1), degree + 1)
-# finat_ps = finat.point_set.GaussLobattoLegendrePointSet
-# finat_qr = finat.quadrature.QuadratureRule
-# return finat_qr(finat_ps(fiat_rule.get_points()), fiat_rule.get_weights())
-
-
-# # 3D
-# def gauss_lobatto_legendre_cube_rule(dimension, degree):
-# make_tensor_rule = finat.quadrature.TensorProductQuadratureRule
-# result = gauss_lobatto_legendre_line_rule(degree)
-# for _ in range(1, dimension):
-# line_rule = gauss_lobatto_legendre_line_rule(degree)
-# result = make_tensor_rule([result, line_rule])
-# return result
-
-
-# @ensemble_gradient
-# def gradient(
-# model,
-# mesh,
-# comm,
-# c,
-# receivers,
-# guess,
-# residual,
-# output=False,
-# save_adjoint=False,
-# ):
-# """Discrete adjoint with secord-order in time fully-explicit timestepping scheme
-# with implementation of a Perfectly Matched Layer (PML) using
-# CG FEM with or without higher order mass lumping (KMV type elements).
-
-# Parameters
-# ----------
-# model: Python `dictionary`
-# Contains model options and parameters
-# mesh: Firedrake.mesh object
-# The 2D/3D triangular mesh
-# comm: Firedrake.ensemble_communicator
-# The MPI communicator for parallelism
-# c: Firedrake.Function
-# The velocity model interpolated onto the mesh nodes.
-# receivers: A :class:`spyro.Receivers` object.
-# Contains the receiver locations and sparse interpolation methods.
-# guess: A list of Firedrake functions
-# Contains the forward wavefield at a set of timesteps
-# residual: array-like [timesteps][receivers]
-# The difference between the observed and modeled data at
-# the receivers
-# output: boolean
-# optional, write the adjoint to disk (only for debugging)
-# save_adjoint: A list of Firedrake functions
-# Contains the adjoint at all timesteps
-
-# Returns
-# -------
-# dJdc_local: A Firedrake.Function containing the gradient of
-# the functional w.r.t. `c`
-# adjoint: Optional, a list of Firedrake functions containing the adjoint
-
-# """
-
-# method = model["opts"]["method"]
-# degree = model["opts"]["degree"]
-# dimension = model["opts"]["dimension"]
-# dt = model["timeaxis"]["dt"]
-# tf = model["timeaxis"]["tf"]
-# nspool = model["timeaxis"]["nspool"]
-# fspool = model["timeaxis"]["fspool"]
-
-# params = {"ksp_type": "cg", "pc_type": "jacobi"}
-
-# element = fire.FiniteElement(
-# method, mesh.ufl_cell(), degree=degree, variant="spectral"
-# )
-
-# V = fire.FunctionSpace(mesh, element)
-
-# qr_x = gauss_lobatto_legendre_cube_rule(dimension=dimension, degree=degree)
-# qr_s = gauss_lobatto_legendre_cube_rule(
-# dimension=(dimension - 1), degree=degree
-# )
-
-# nt = int(tf / dt) # number of timesteps
-
-# dJ = fire.Function(V, name="gradient")
-
-# # typical CG in N-d
-# u = fire.TrialFunction(V)
-# v = fire.TestFunction(V)
-
-# u_nm1 = fire.Function(V)
-# u_n = fire.Function(V)
-# u_np1 = fire.Function(V)
-
-# if output:
-# outfile = helpers.create_output_file("adjoint.pvd", comm, 0)
-
-# t = 0.0
-
-# # -------------------------------------------------------
-# m1 = ((u - 2.0 * u_n + u_nm1) / Constant(dt**2)) * v * dx(scheme=qr_x)
-# a = c * c * dot(grad(u_n), grad(v)) * dx(scheme=qr_x) # explicit
-
-# lhs1 = m1
-# rhs1 = -a
-
-# X = fire.Function(V)
-# B = fire.Function(V)
-
-# A = fire.assemble(lhs1, mat_type="matfree")
-# solver = fire.LinearSolver(A, solver_parameters=params)
-
-# # Define gradient problem
-# m_u = fire.TrialFunction(V)
-# m_v = fire.TestFunction(V)
-# mgrad = m_u * m_v * dx(rule=qr_x)
-
-# uuadj = fire.Function(V) # auxiliarly function for the gradient compt.
-# uufor = fire.Function(V) # auxiliarly function for the gradient compt.
-
-# ffG = 2.0 * c * dot(grad(uuadj), grad(uufor)) * m_v * dx(scheme=qr_x)
-
-# lhsG = mgrad
-# rhsG = ffG
-
-# gradi = fire.Function(V)
-# grad_prob = fire.LinearVariationalProblem(lhsG, rhsG, gradi)
-
-# if method == "KMV":
-# grad_solver = fire.LinearVariationalSolver(
-# grad_prob,
-# solver_parameters={
-# "ksp_type": "preonly",
-# "pc_type": "jacobi",
-# "mat_type": "matfree",
-# },
-# )
-# elif method == "CG":
-# grad_solver = fire.LinearVariationalSolver(
-# grad_prob,
-# solver_parameters={
-# "mat_type": "matfree",
-# },
-# )
-
-# assembly_callable = create_assembly_callable(rhs1, tensor=B)
-
-# rhs_forcing = fire.Function(V) # forcing term
-# if save_adjoint:
-# adjoint = [
-# fire.Function(V, name="adjoint_pressure") for t in range(nt)
-# ]
-# for step in range(nt - 1, -1, -1):
-# t = step * float(dt)
-# rhs_forcing.assign(0.0)
-# # Solver - main equation - (I)
-# # B = assemble(rhs_, tensor=B)
-# assembly_callable()
-
-# f = receivers.apply_receivers_as_source(rhs_forcing, residual, step)
-# # add forcing term to solve scalar pressure
-# B0 = B.sub(0)
-# B0 += f
-
-# # AX=B --> solve for X = B/Aˆ-1
-# solver.solve(X, B)
-
-# u_np1.assign(X)
-
-# # only compute for snaps that were saved
-# if step % fspool == 0:
-# # compute the gradient increment
-# uuadj.assign(u_np1)
-# uufor.assign(guess.pop())
-
-# grad_solver.solve()
-# dJ += gradi
-
-# u_nm1.assign(u_n)
-# u_n.assign(u_np1)
-
-# if step % nspool == 0:
-# if output:
-# outfile.write(u_n, time=t)
-# if save_adjoint:
-# adjoint.append(u_n)
-# helpers.display_progress(comm, t)
-
-# if save_adjoint:
-# return dJ, adjoint
-# else:
-# return dJ
diff --git a/spyro/solvers/mms_acoustic.py b/spyro/solvers/mms_acoustic.py
index ccbfb28c..6cfebd3f 100644
--- a/spyro/solvers/mms_acoustic.py
+++ b/spyro/solvers/mms_acoustic.py
@@ -1,6 +1,6 @@
import firedrake as fire
from .acoustic_wave import AcousticWave
-
+from ..utils.typing import override
class AcousticWaveMMS(AcousticWave):
"""Class for solving the acoustic wave equation in 2D or 3D using
@@ -10,14 +10,21 @@ class AcousticWaveMMS(AcousticWave):
"""
def matrix_building(self):
+ self.mms_source_in_space()
+ self.q_t = fire.Constant(0)
+ self.source_expression = self.q_t * self.q_xy
+
super().matrix_building()
lhs = self.lhs
bcs = fire.DirichletBC(self.function_space, 0.0, "on_boundary")
A = fire.assemble(lhs, bcs=bcs, mat_type="matfree")
- self.mms_source_in_space()
self.solver = fire.LinearSolver(
A, solver_parameters=self.solver_parameters
)
+ dt = self.dt
+ t = self.current_time
+ self.u_nm1.assign(self.analytical_solution(t - 2 * dt))
+ self.u_n.assign(self.analytical_solution(t - dt))
def mms_source_in_space(self):
V = self.function_space
@@ -45,10 +52,6 @@ def mms_source_in_space(self):
# self.q_xy.interpolate(sin(pi*x)*sin(pi*y))
- def mms_source_in_time(self, t):
- # return fire.Constant(2*pi**2*t**2 + 2.0)
- return fire.Constant(2 * t)
-
def analytical_solution(self, t):
self.analytical = fire.Function(self.function_space)
x = self.mesh_z
@@ -66,3 +69,7 @@ def analytical_solution(self, t):
# self.analytical.assign(analytical)
return self.analytical
+
+ @override
+ def update_source_expression(self, t):
+ self.q_t.assign(2*t)
diff --git a/spyro/solvers/time_integration.py b/spyro/solvers/time_integration.py
index 889fd834..45573f12 100644
--- a/spyro/solvers/time_integration.py
+++ b/spyro/solvers/time_integration.py
@@ -1,26 +1,4 @@
from .time_integration_central_difference import central_difference
-from .time_integration_central_difference import mixed_space_central_difference
-from .time_integration_central_difference import central_difference_MMS
-
def time_integrator(Wave_object, source_id=0):
- if Wave_object.source_type == "ricker":
- return time_integrator_ricker(Wave_object, source_id=source_id)
- elif Wave_object.source_type == "MMS":
- return time_integrator_mms(Wave_object, source_id=source_id)
-
-
-def time_integrator_ricker(Wave_object, source_id=0):
- if Wave_object.time_integrator == "central_difference":
- return central_difference(Wave_object, source_id=source_id)
- elif Wave_object.time_integrator == "mixed_space_central_difference":
- return mixed_space_central_difference(Wave_object, source_id=source_id)
- else:
- raise ValueError("The time integrator specified is not implemented yet")
-
-
-def time_integrator_mms(Wave_object, source_id=0):
- if Wave_object.time_integrator == "central_difference":
- return central_difference_MMS(Wave_object, source_id=source_id)
- else:
- raise ValueError("The time integrator specified is not implemented yet")
+ return central_difference(Wave_object, source_id=source_id)
diff --git a/spyro/solvers/time_integration_central_difference.py b/spyro/solvers/time_integration_central_difference.py
index 410bf5cd..3bb4541b 100644
--- a/spyro/solvers/time_integration_central_difference.py
+++ b/spyro/solvers/time_integration_central_difference.py
@@ -1,18 +1,17 @@
import firedrake as fire
-from firedrake import Constant, dx, dot, grad
from ..io.basicio import parallel_print
from . import helpers
from .. import utils
-def central_difference(Wave_object, source_id=0):
+def central_difference(wave, source_id=0):
"""
Perform central difference time integration for wave propagation.
Parameters:
-----------
- Wave_object: Spyro object
+ wave: Spyro object
The Wave object containing the necessary data and parameters.
source_id: int (optional)
The ID of the source being propagated. Defaults to 0.
@@ -22,296 +21,74 @@ def central_difference(Wave_object, source_id=0):
tuple:
A tuple containing the forward solution and the receiver output.
"""
- excitations = Wave_object.sources
- excitations.current_source = source_id
- receivers = Wave_object.receivers
- comm = Wave_object.comm
- temp_filename = Wave_object.forward_output_file
+ if wave.sources is not None:
+ wave.sources.current_source = source_id
+ rhs_forcing = fire.Function(wave.function_space)
- filename, file_extension = temp_filename.split(".")
+ filename, file_extension = wave.forward_output_file.split(".")
output_filename = filename + "sn" + str(source_id) + "." + file_extension
- if Wave_object.forward_output:
- parallel_print(f"Saving output in: {output_filename}", Wave_object.comm)
-
- output = fire.File(output_filename, comm=comm.comm)
- comm.comm.barrier()
-
- X = fire.Function(Wave_object.function_space)
-
- final_time = Wave_object.final_time
- dt = Wave_object.dt
- t = Wave_object.current_time
- nt = int((final_time - t) / dt) + 1 # number of timesteps
-
- u_nm1 = Wave_object.u_nm1
- u_n = Wave_object.u_n
- u_np1 = fire.Function(Wave_object.function_space)
-
- rhs_forcing = fire.Function(Wave_object.function_space)
- usol = [
- fire.Function(Wave_object.function_space, name="pressure")
- for t in range(nt)
- if t % Wave_object.gradient_sampling_frequency == 0
- ]
- usol_recv = []
- save_step = 0
- B = Wave_object.B
- rhs = Wave_object.rhs
-
- for step in range(nt):
- rhs_forcing.assign(0.0)
- B = fire.assemble(rhs, tensor=B)
- f = excitations.apply_source(rhs_forcing, Wave_object.wavelet[step])
- B0 = B.sub(0)
- B0 += f
- Wave_object.solver.solve(X, B)
-
- u_np1.assign(X)
-
- usol_recv.append(
- Wave_object.receivers.interpolate(u_np1.dat.data_ro_with_halos[:])
- )
-
- if step % Wave_object.gradient_sampling_frequency == 0:
- usol[save_step].assign(u_np1)
- save_step += 1
-
- if (step - 1) % Wave_object.output_frequency == 0:
- assert (
- fire.norm(u_n) < 1
- ), "Numerical instability. Try reducing dt or building the \
- mesh differently"
- if Wave_object.forward_output:
- output.write(u_n, time=t, name="Pressure")
-
- helpers.display_progress(Wave_object.comm, t)
-
- u_nm1.assign(u_n)
- u_n.assign(u_np1)
-
- t = step * float(dt)
-
- Wave_object.current_time = t
- helpers.display_progress(Wave_object.comm, t)
-
- usol_recv = helpers.fill(
- usol_recv, receivers.is_local, nt, receivers.number_of_points
- )
- usol_recv = utils.utils.communicate(usol_recv, comm)
- Wave_object.receivers_output = usol_recv
-
- Wave_object.forward_solution = usol
- Wave_object.forward_solution_receivers = usol_recv
-
- return usol, usol_recv
-
-
-def mixed_space_central_difference(Wave_object, source_id=0):
- """
- Performs central difference time integration for wave propagation.
- Solves for a mixed space formulation, for function X. For correctly
- outputing pressure, order the mixed function space so that the space
- pressure lives in is first.
-
- Parameters:
- -----------
- Wave_object: Spyro object
- The Wave object containing the necessary data and parameters.
- source_id: int (optional)
- The ID of the source being propagated. Defaults to 0.
-
- Returns:
- --------
- tuple:
- A tuple containing the forward solution and the receiver output.
- """
- excitations = Wave_object.sources
- excitations.current_source = source_id
- receivers = Wave_object.receivers
- comm = Wave_object.comm
- temp_filename = Wave_object.forward_output_file
- filename, file_extension = temp_filename.split(".")
- output_filename = filename + "sn" + str(source_id) + "." + file_extension
- if Wave_object.forward_output:
- parallel_print(f"Saving output in: {output_filename}", Wave_object.comm)
-
- output = fire.File(output_filename, comm=comm.comm)
- comm.comm.barrier()
-
- final_time = Wave_object.final_time
- dt = Wave_object.dt
- t = Wave_object.current_time
- nt = int(final_time / dt) + 1 # number of timesteps
-
- X = Wave_object.X
- X_n = Wave_object.X_n
- X_nm1 = Wave_object.X_nm1
-
- rhs_forcing = fire.Function(Wave_object.function_space)
+ if wave.forward_output:
+ parallel_print(f"Saving output in: {output_filename}", wave.comm)
+
+ output = fire.File(output_filename, comm=wave.comm.comm)
+ wave.comm.comm.barrier()
+
+ t = wave.current_time
+ nt = int(wave.final_time / wave.dt) + 1 # number of timesteps
+
usol = [
- fire.Function(Wave_object.function_space, name="pressure")
+ fire.Function(wave.function_space, name=wave.get_function_name())
for t in range(nt)
- if t % Wave_object.gradient_sampling_frequency == 0
+ if t % wave.gradient_sampling_frequency == 0
]
usol_recv = []
save_step = 0
- B = Wave_object.B
- rhs_ = Wave_object.rhs
-
for step in range(nt):
- rhs_forcing.assign(0.0)
- B = fire.assemble(rhs_, tensor=B)
- f = excitations.apply_source(rhs_forcing, Wave_object.wavelet[step])
- B0 = B.sub(0)
- B0 += f
- Wave_object.solver.solve(X, B)
-
- X_np1 = X
-
- X_nm1.assign(X_n)
- X_n.assign(X_np1)
-
- usol_recv.append(
- Wave_object.receivers.interpolate(
- X_np1.dat.data_ro_with_halos[0][:]
- )
- )
-
- if step % Wave_object.gradient_sampling_frequency == 0:
- usol[save_step].assign(X_np1.sub(0))
+ # Basic way of applying sources
+ wave.update_source_expression(t)
+ fire.assemble(wave.rhs, tensor=wave.B)
+
+ # More efficient way of applying sources
+ if wave.sources is not None:
+ rhs_forcing.assign(0.0)
+ f = wave.sources.apply_source(rhs_forcing, wave.wavelet[step])
+ B0 = wave.B.sub(0)
+ B0 += f
+
+ wave.solver.solve(wave.next_vstate, wave.B)
+
+ wave.prev_vstate = wave.vstate
+ wave.vstate = wave.next_vstate
+
+ usol_recv.append(wave.get_receivers_output())
+
+ if step % wave.gradient_sampling_frequency == 0:
+ usol[save_step].assign(wave.get_function())
save_step += 1
- if (step - 1) % Wave_object.output_frequency == 0:
+ if (step - 1) % wave.output_frequency == 0:
assert (
- fire.norm(X_np1.sub(0)) < 1
- ), "Numerical instability. Try reducing dt or building the \
- mesh differently"
- if Wave_object.forward_output:
- output.write(X_np1.sub(0), time=t, name="Pressure")
-
- helpers.display_progress(comm, t)
+ fire.norm(wave.get_function()) < 1
+ ), "Numerical instability. Try reducing dt or building the " \
+ "mesh differently"
+ if wave.forward_output:
+ output.write(wave.get_function(), time=t,
+ name=wave.get_function_name())
- t = step * float(dt)
+ helpers.display_progress(wave.comm, t)
- Wave_object.current_time = t
- helpers.display_progress(Wave_object.comm, t)
+ t = step * float(wave.dt)
+
+ wave.current_time = t
+ helpers.display_progress(wave.comm, t)
usol_recv = helpers.fill(
- usol_recv, receivers.is_local, nt, receivers.number_of_points
+ usol_recv, wave.receivers.is_local, nt, wave.receivers.number_of_points
)
- usol_recv = utils.utils.communicate(usol_recv, comm)
- Wave_object.receivers_output = usol_recv
-
- Wave_object.forward_solution = usol
- Wave_object.forward_solution_receivers = usol_recv
-
- return usol, usol_recv
-
-
-def central_difference_MMS(Wave_object, source_id=0):
- """Propagates the wave forward in time.
- Currently uses central differences.
-
- Parameters:
- -----------
- dt: Python 'float' (optional)
- Time step to be used explicitly. If not mentioned uses the default,
- that was estabilished in the model_parameters.
- final_time: Python 'float' (optional)
- Time which simulation ends. If not mentioned uses the default,
- that was estabilished in the model_parameters.
- """
- receivers = Wave_object.receivers
- comm = Wave_object.comm
- temp_filename = Wave_object.forward_output_file
- filename, file_extension = temp_filename.split(".")
- output_filename = filename + "sn_mms_" + "." + file_extension
- if Wave_object.forward_output:
- print(f"Saving output in: {output_filename}", flush=True)
-
- output = fire.File(output_filename, comm=comm.comm)
- comm.comm.barrier()
-
- X = fire.Function(Wave_object.function_space)
-
- final_time = Wave_object.final_time
- dt = Wave_object.dt
- t = Wave_object.current_time
- nt = int((final_time - t) / dt) + 1 # number of timesteps
-
- u_nm1 = Wave_object.u_nm1
- u_n = Wave_object.u_n
- u_nm1.assign(Wave_object.analytical_solution(t - 2 * dt))
- u_n.assign(Wave_object.analytical_solution(t - dt))
- u_np1 = fire.Function(Wave_object.function_space, name="pressure t +dt")
- u = fire.TrialFunction(Wave_object.function_space)
- v = fire.TestFunction(Wave_object.function_space)
-
- usol = [
- fire.Function(Wave_object.function_space, name="pressure")
- for t in range(nt)
- if t % Wave_object.gradient_sampling_frequency == 0
- ]
- usol_recv = []
- save_step = 0
- B = Wave_object.B
- rhs = Wave_object.rhs
- quad_rule = Wave_object.quadrature_rule
-
- q_xy = Wave_object.q_xy
-
- for step in range(nt):
- q = q_xy * Wave_object.mms_source_in_time(t)
- m1 = (
- 1
- / (Wave_object.c * Wave_object.c)
- * ((u - 2.0 * u_n + u_nm1) / Constant(dt**2))
- * v
- * dx(scheme=quad_rule)
- )
- a = dot(grad(u_n), grad(v)) * dx(scheme=quad_rule)
- le = q * v * dx(scheme=quad_rule)
-
- form = m1 + a - le
- rhs = fire.rhs(form)
-
- B = fire.assemble(rhs, tensor=B)
-
- Wave_object.solver.solve(X, B)
+ usol_recv = utils.utils.communicate(usol_recv, wave.comm)
+ wave.receivers_output = usol_recv
- u_np1.assign(X)
-
- usol_recv.append(
- Wave_object.receivers.interpolate(u_np1.dat.data_ro_with_halos[:])
- )
-
- if step % Wave_object.gradient_sampling_frequency == 0:
- usol[save_step].assign(u_np1)
- save_step += 1
-
- if (step - 1) % Wave_object.output_frequency == 0:
- assert (
- fire.norm(u_n) < 1
- ), "Numerical instability. Try reducing dt or building the \
- mesh differently"
- if Wave_object.forward_output:
- output.write(u_n, time=t, name="Pressure")
- if t > 0:
- helpers.display_progress(Wave_object.comm, t)
-
- u_nm1.assign(u_n)
- u_n.assign(u_np1)
-
- t = step * float(dt)
-
- Wave_object.current_time = t
- helpers.display_progress(Wave_object.comm, t)
- Wave_object.analytical_solution(t)
-
- usol_recv = helpers.fill(
- usol_recv, receivers.is_local, nt, receivers.number_of_points
- )
- usol_recv = utils.utils.communicate(usol_recv, comm)
- Wave_object.receivers_output = usol_recv
+ wave.forward_solution = usol
+ wave.forward_solution_receivers = usol_recv
- return usol, usol_recv
+ return usol, usol_recv
\ No newline at end of file
diff --git a/spyro/solvers/wave.py b/spyro/solvers/wave.py
index 4351d17c..293b29d2 100644
--- a/spyro/solvers/wave.py
+++ b/spyro/solvers/wave.py
@@ -87,6 +87,8 @@ def __init__(self, dictionary=None, comm=None):
)
else:
warnings.warn("No mesh found. Please define a mesh.")
+ # Expression to define sources through UFL (less efficient)
+ self.source_expression = None
@abstractmethod
def forward_solve(self):
@@ -209,8 +211,8 @@ def set_initial_velocity_model(
self.initial_velocity_model = vp
else:
raise ValueError(
- "Please specify either a conditional, expression, firedrake \
- function or new file name (segy or hdf5)."
+ "Please specify either a conditional, expression, firedrake " \
+ "function or new file name (segy or hdf5)."
)
if output:
fire.File("initial_velocity_model.pvd").write(
@@ -223,35 +225,10 @@ def _map_sources_and_receivers(self):
else:
self.sources = None
self.receivers = Receivers(self)
-
- def _get_initial_velocity_model(self):
- if self.initial_velocity_model is not None:
- return None
-
- if self.initial_velocity_model_file is None:
- raise ValueError("No velocity model or velocity file to load.")
-
- if self.initial_velocity_model_file.endswith(".segy"):
- vp_filename, vp_filetype = os.path.splitext(
- self.initial_velocity_model_file
- )
- warnings.warn("Converting segy file to hdf5")
- write_velocity_model(
- self.initial_velocity_model_file, ofname=vp_filename
- )
- self.initial_velocity_model_file = vp_filename + ".hdf5"
-
- if self.initial_velocity_model_file.endswith(".hdf5"):
- self.initial_velocity_model = interpolate(
- self,
- self.initial_velocity_model_file,
- self.function_space.sub(0),
- )
-
- if self.debug_output:
- fire.File("initial_velocity_model.pvd").write(
- self.initial_velocity_model, name="velocity"
- )
+
+ @abstractmethod
+ def _initialize_model_parameters(self):
+ pass
def _build_function_space(self):
self.function_space = FE_method(self.mesh, self.method, self.degree)
@@ -315,3 +292,56 @@ def set_last_solve_as_real_shot_record(self):
if self.current_time == 0.0:
raise ValueError("No previous solve to set as real shot record.")
self.real_shot_record = self.forward_solution_receivers
+
+ @abstractmethod
+ def _set_vstate(self, vstate):
+ pass
+
+ @abstractmethod
+ def _get_vstate(self):
+ pass
+
+ @abstractmethod
+ def _set_prev_vstate(self, vstate):
+ pass
+
+ @abstractmethod
+ def _get_prev_vstate(self):
+ pass
+
+ @abstractmethod
+ def _set_next_vstate(self, vstate):
+ pass
+
+ @abstractmethod
+ def _get_next_vstate(self):
+ pass
+
+ # Managed attributes to access state variables in current, previous and next iteration
+ vstate = property(fget=lambda self: self._get_vstate(),
+ fset=lambda self, value: self._set_vstate(value))
+ prev_vstate = property(fget=lambda self: self._get_prev_vstate(),
+ fset=lambda self, value: self._set_prev_vstate(value))
+ next_vstate = property(fget=lambda self: self._get_next_vstate(),
+ fset=lambda self, value: self._set_next_vstate(value))
+
+ @abstractmethod
+ def get_receivers_output(self):
+ pass
+
+ @abstractmethod
+ def get_function(self):
+ '''Returns the function (e.g., pressure or displacement) associated with
+ the wave object without additional variables (e.g., PML variables)'''
+ pass
+
+ @abstractmethod
+ def get_function_name(self):
+ '''Returns the string representing the function of the wave object
+ (e.g., "pressure" or "displacement")'''
+ pass
+
+ def update_source_expression(self, t):
+ '''Update the source expression during wave propagation. This method must be
+ implemented only by subclasses that make use of the source term'''
+ pass
diff --git a/spyro/tools/cells_per_wavelength_calculator.py b/spyro/tools/cells_per_wavelength_calculator.py
index 8c92391c..af6e7895 100644
--- a/spyro/tools/cells_per_wavelength_calculator.py
+++ b/spyro/tools/cells_per_wavelength_calculator.py
@@ -347,7 +347,7 @@ def find_minimum(self, starting_cpw=None, TOL=None, accuracy=None, savetxt=False
# Running forward model
Wave_obj = self.build_current_object(cpw)
- Wave_obj._get_initial_velocity_model()
+ Wave_obj._initialize_model_parameters() # TO REVIEW: call to protected method
# Setting up time-step
if self.timestep_calculation != "float":
diff --git a/spyro/utils/typing.py b/spyro/utils/typing.py
new file mode 100644
index 00000000..b393ccb6
--- /dev/null
+++ b/spyro/utils/typing.py
@@ -0,0 +1,6 @@
+def override(func):
+ '''
+ This decorator should be replaced by typing.override when Python
+ version is updated to 3.12
+ '''
+ return func
\ No newline at end of file
diff --git a/test/test_MMS.py b/test/test_MMS.py
index ab2cb44c..e44b90d1 100644
--- a/test/test_MMS.py
+++ b/test/test_MMS.py
@@ -35,7 +35,7 @@ def run_solve(model):
Wave_obj.set_initial_velocity_model(expression="1 + sin(pi*-z)*sin(pi*x)")
Wave_obj.forward_solve()
- u_an = Wave_obj.analytical
+ u_an = Wave_obj.analytical_solution(Wave_obj.current_time)
u_num = Wave_obj.u_n
return errornorm(u_num, u_an)
diff --git a/test/test_isotropic_wave.py b/test/test_isotropic_wave.py
new file mode 100644
index 00000000..ed5a3ccb
--- /dev/null
+++ b/test/test_isotropic_wave.py
@@ -0,0 +1,67 @@
+import pytest
+
+from spyro.solvers.elastic_wave.isotropic_wave import IsotropicWave
+
+# TO REVIEW: it is extra work to have to define this dictionary everytime
+# Here I listed only the required parameters for running to get a view of
+# what is currently necessary. Note that the dictionary is not even complete
+dummy_dict = {
+ "options": {
+ "cell_type": "T",
+ "variant": "lumped",
+ },
+ "time_axis": {
+ "final_time": 1,
+ "dt": 0.001,
+ "output_frequency": 100,
+ "gradient_sampling_frequency": 1,
+ },
+ "mesh": {},
+ "acquisition": {
+ "receiver_locations": [],
+ "source_type": "ricker",
+ "source_locations": [(0, 0)],
+ "frequency": 5.0,
+ },
+}
+
+def test_initialize_model_parameters_from_object_missing_parameters():
+ synthetic_dict = {
+ "type": "object",
+ }
+ wave = IsotropicWave(dummy_dict)
+ with pytest.raises(Exception) as e:
+ wave.initialize_model_parameters_from_object(synthetic_dict)
+
+def test_initialize_model_parameters_from_object_first_option():
+ synthetic_dict = {
+ "type": "object",
+ "density": 1,
+ "lambda": 2,
+ "mu": 3,
+ }
+ wave = IsotropicWave(dummy_dict)
+ wave.initialize_model_parameters_from_object(synthetic_dict)
+
+def test_initialize_model_parameters_from_object_second_option():
+ synthetic_dict = {
+ "type": "object",
+ "density": 1,
+ "p_wave_velocity": 2,
+ "s_wave_velocity": 3,
+ }
+ wave = IsotropicWave(dummy_dict)
+ wave.initialize_model_parameters_from_object(synthetic_dict)
+
+def test_initialize_model_parameters_from_object_redundant():
+ synthetic_dict = {
+ "type": "object",
+ "density": 1,
+ "lmbda": 2,
+ "mu": 3,
+ "p_wave_velocity": 2,
+ "s_wave_velocity": 3,
+ }
+ wave = IsotropicWave(dummy_dict)
+ with pytest.raises(Exception) as e:
+ wave.initialize_model_parameters_from_object(synthetic_dict)
\ No newline at end of file