Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 0041 - Add elastic wave propagation #53

Merged
merged 7 commits into from
Jul 22, 2024
7 changes: 6 additions & 1 deletion spyro/solvers/acoustic_solver_construction_no_pml.py
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great!

Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,14 @@ def construct_solver_or_matrix_no_pml(Wave_object):
)
a = dot(grad(u_n), grad(v)) * dx(scheme=quad_rule) # explicit

le = 0
q = Wave_object.source_expression
if q is not None:
le = q * v * dx(scheme=quad_rule)

B = fire.Function(V)

form = m1 + a
form = m1 + a - le
lhs = fire.lhs(form)
rhs = fire.rhs(form)
Wave_object.lhs = lhs
Expand Down
19 changes: 13 additions & 6 deletions spyro/solvers/mms_acoustic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import firedrake as fire
from .acoustic_wave import AcousticWave

from ..utils.typing import override

class AcousticWaveMMS(AcousticWave):
"""Class for solving the acoustic wave equation in 2D or 3D using
Expand All @@ -10,14 +10,21 @@ class AcousticWaveMMS(AcousticWave):
"""

def matrix_building(self):
self.mms_source_in_space()
self.q_t = fire.Constant(0)
self.source_expression = self.q_t * self.q_xy

super().matrix_building()
lhs = self.lhs
bcs = fire.DirichletBC(self.function_space, 0.0, "on_boundary")
A = fire.assemble(lhs, bcs=bcs, mat_type="matfree")
self.mms_source_in_space()
self.solver = fire.LinearSolver(
A, solver_parameters=self.solver_parameters
)
dt = self.dt
t = self.current_time
self.u_nm1.assign(self.analytical_solution(t - 2 * dt))
self.u_n.assign(self.analytical_solution(t - dt))

def mms_source_in_space(self):
V = self.function_space
Expand Down Expand Up @@ -45,10 +52,6 @@ def mms_source_in_space(self):

# self.q_xy.interpolate(sin(pi*x)*sin(pi*y))

def mms_source_in_time(self, t):
# return fire.Constant(2*pi**2*t**2 + 2.0)
return fire.Constant(2 * t)

def analytical_solution(self, t):
self.analytical = fire.Function(self.function_space)
x = self.mesh_z
Expand All @@ -66,3 +69,7 @@ def analytical_solution(self, t):
# self.analytical.assign(analytical)

return self.analytical

@override
def update_source_expression(self, t):
self.q_t.assign(2*t)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm uncertain if this is the best practice. While using MMS for this small problem might not cause any issues, for larger cases in the future, we should prioritize wrapping scalar time-dependent variables inside a fire.Constant for better performance.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't get the issue here. The variable "q_t" is defined as a Constant is line 14. This should be faster than before because the method was returning new constants in every call. As far as I remember, using the same constant avoids some parts of JIT or assembling.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see now. I mistakenly thought that q_t was defined as a Function rather than a Constant, I didn't see line 14. Thanks for clarifying. Yes, reusing the same constant does indeed optimize the process by avoiding the overhead associated with redefining it repeatedly, leading to better performance.

15 changes: 0 additions & 15 deletions spyro/solvers/time_integration.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,4 @@
from .time_integration_central_difference import central_difference
from .time_integration_central_difference import central_difference_MMS


def time_integrator(Wave_object, source_id=0):
if Wave_object.source_type == "ricker":
return time_integrator_ricker(Wave_object, source_id=source_id)
elif Wave_object.source_type == "MMS":
return time_integrator_mms(Wave_object, source_id=source_id)


def time_integrator_ricker(Wave_object, source_id=0):
return central_difference(Wave_object, source_id=source_id)

def time_integrator_mms(Wave_object, source_id=0):
if Wave_object.time_integrator == "central_difference":
return central_difference_MMS(Wave_object, source_id=source_id)
else:
raise ValueError("The time integrator specified is not implemented yet")
129 changes: 13 additions & 116 deletions spyro/solvers/time_integration_central_difference.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import firedrake as fire
from firedrake import Constant, dx, dot, grad

from ..io.basicio import parallel_print
from . import helpers
Expand All @@ -22,7 +21,9 @@ def central_difference(wave, source_id=0):
tuple:
A tuple containing the forward solution and the receiver output.
"""
wave.sources.current_source = source_id
if wave.sources is not None:
wave.sources.current_source = source_id
rhs_forcing = fire.Function(wave.function_space)

filename, file_extension = wave.forward_output_file.split(".")
output_filename = filename + "sn" + str(source_id) + "." + file_extension
Expand All @@ -35,7 +36,6 @@ def central_difference(wave, source_id=0):
t = wave.current_time
nt = int(wave.final_time / wave.dt) + 1 # number of timesteps

rhs_forcing = fire.Function(wave.function_space)
usol = [
fire.Function(wave.function_space, name=wave.get_function_name())
for t in range(nt)
Expand All @@ -44,11 +44,16 @@ def central_difference(wave, source_id=0):
usol_recv = []
save_step = 0
for step in range(nt):
rhs_forcing.assign(0.0)
# Basic way of applying sources
wave.update_source_expression(t)
fire.assemble(wave.rhs, tensor=wave.B)
f = wave.sources.apply_source(rhs_forcing, wave.wavelet[step])
B0 = wave.B.sub(0)
B0 += f

# More efficient way of applying sources
if wave.sources is not None:
rhs_forcing.assign(0.0)
f = wave.sources.apply_source(rhs_forcing, wave.wavelet[step])
B0 = wave.B.sub(0)
B0 += f

wave.solver.solve(wave.next_vstate, wave.B)

Expand Down Expand Up @@ -86,112 +91,4 @@ def central_difference(wave, source_id=0):
wave.forward_solution = usol
wave.forward_solution_receivers = usol_recv

return usol, usol_recv

def central_difference_MMS(Wave_object, source_id=0):
"""Propagates the wave forward in time.
Currently uses central differences.

Parameters:
-----------
dt: Python 'float' (optional)
Time step to be used explicitly. If not mentioned uses the default,
that was estabilished in the model_parameters.
final_time: Python 'float' (optional)
Time which simulation ends. If not mentioned uses the default,
that was estabilished in the model_parameters.
"""
receivers = Wave_object.receivers
comm = Wave_object.comm
temp_filename = Wave_object.forward_output_file
filename, file_extension = temp_filename.split(".")
output_filename = filename + "sn_mms_" + "." + file_extension
if Wave_object.forward_output:
print(f"Saving output in: {output_filename}", flush=True)

output = fire.File(output_filename, comm=comm.comm)
comm.comm.barrier()

X = fire.Function(Wave_object.function_space)

final_time = Wave_object.final_time
dt = Wave_object.dt
t = Wave_object.current_time
nt = int((final_time - t) / dt) + 1 # number of timesteps

u_nm1 = Wave_object.u_nm1
u_n = Wave_object.u_n
u_nm1.assign(Wave_object.analytical_solution(t - 2 * dt))
u_n.assign(Wave_object.analytical_solution(t - dt))
u_np1 = fire.Function(Wave_object.function_space, name="pressure t +dt")
u = fire.TrialFunction(Wave_object.function_space)
v = fire.TestFunction(Wave_object.function_space)

usol = [
fire.Function(Wave_object.function_space, name="pressure")
for t in range(nt)
if t % Wave_object.gradient_sampling_frequency == 0
]
usol_recv = []
save_step = 0
B = Wave_object.B
rhs = Wave_object.rhs
quad_rule = Wave_object.quadrature_rule

q_xy = Wave_object.q_xy

for step in range(nt):
q = q_xy * Wave_object.mms_source_in_time(t)
m1 = (
1
/ (Wave_object.c * Wave_object.c)
* ((u - 2.0 * u_n + u_nm1) / Constant(dt**2))
* v
* dx(scheme=quad_rule)
)
a = dot(grad(u_n), grad(v)) * dx(scheme=quad_rule)
le = q * v * dx(scheme=quad_rule)

form = m1 + a - le
rhs = fire.rhs(form)

B = fire.assemble(rhs, tensor=B)

Wave_object.solver.solve(X, B)

u_np1.assign(X)

usol_recv.append(
Wave_object.receivers.interpolate(u_np1.dat.data_ro_with_halos[:])
)

if step % Wave_object.gradient_sampling_frequency == 0:
usol[save_step].assign(u_np1)
save_step += 1

if (step - 1) % Wave_object.output_frequency == 0:
assert (
fire.norm(u_n) < 1
), "Numerical instability. Try reducing dt or building the " \
"mesh differently"
if Wave_object.forward_output:
output.write(u_n, time=t, name="Pressure")
if t > 0:
helpers.display_progress(Wave_object.comm, t)

u_nm1.assign(u_n)
u_n.assign(u_np1)

t = step * float(dt)

Wave_object.current_time = t
helpers.display_progress(Wave_object.comm, t)
Wave_object.analytical_solution(t)

usol_recv = helpers.fill(
usol_recv, receivers.is_local, nt, receivers.number_of_points
)
usol_recv = utils.utils.communicate(usol_recv, comm)
Wave_object.receivers_output = usol_recv

return usol, usol_recv
return usol, usol_recv
9 changes: 8 additions & 1 deletion spyro/solvers/wave.py
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I renamed the _get_initial_velocity_model to _initialize_model_parameters because the elastic wave propagators have more parameters. I also moved the implementation to AcousticWave to allow ElasticWave to inherit Wave behavior consistently.

Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def __init__(self, dictionary=None, comm=None):
)
else:
warnings.warn("No mesh found. Please define a mesh.")
# Expression to define sources through UFL (less efficient)
self.source_expression = None

@abstractmethod
def forward_solve(self):
Expand Down Expand Up @@ -337,4 +339,9 @@ def get_function(self):
def get_function_name(self):
'''Returns the string representing the function of the wave object
(e.g., "pressure" or "displacement")'''
pass
pass

def update_source_expression(self, t):
'''Update the source expression during wave propagation. This method must be
implemented only by subclasses that make use of the source term'''
pass
2 changes: 1 addition & 1 deletion test/test_MMS.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def run_solve(model):
Wave_obj.set_initial_velocity_model(expression="1 + sin(pi*-z)*sin(pi*x)")
Wave_obj.forward_solve()

u_an = Wave_obj.analytical
u_an = Wave_obj.analytical_solution(Wave_obj.current_time)
u_num = Wave_obj.u_n

return errornorm(u_num, u_an)
Expand Down
Loading