Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions devito/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def __init__(self, *args, **kwargs):
return
else:
super(TimeData, self).__init__(*args, **kwargs)
self._full_data = self._data.view() if self._data else None
time_dim = kwargs.get('time_dim')
self.time_order = kwargs.get('time_order', 1)
self.save = kwargs.get('save', False)
Expand Down Expand Up @@ -295,9 +296,39 @@ def _allocate_memory(self):
"""function to allocate memmory in terms of numpy ndarrays."""
super(TimeData, self)._allocate_memory()

self._full_data = self._data.view()

if self.pad_time:
self._data = self._data[self.time_order:, :, :]

def init_data(self, timestep, data):
"""Function to initialize the initial time steps

:param timestep: Time step to initialize.
Must be negative since calculated timesteps start from 0.
:param data: :class:`numpy.ndarray` containing the initial spatial data
"""
if self._full_data is None:
self._allocate_memory()

assert timestep < 0, "Timestep must be negative"
assert data.shape == self._full_data[0].shape, \
"Data must have the same shape as the spatial data"

# Adds the time_order to the index to access padded indexes
timestep += self.time_order
self._full_data[timestep] = data

def get_data(self, timestep=0):
"""Returns the calculated data at the specified timestep

:param timestep: The timestep from which we want to retrieve the data.
Specify only in the case :obj:`self.save` is True
"""
timestep += self.time_order

return self._full_data[timestep, :]

@property
def dim(self):
"""Returns the spatial dimension of the data object"""
Expand Down
44 changes: 44 additions & 0 deletions tests/test_save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import numpy as np
from sympy import Eq, solve, symbols

from devito.interfaces import TimeData
from devito.operator import Operator


def initial(dx=0.01, dy=0.01):
nx, ny = int(1 / dx), int(1 / dy)
xx, yy = np.meshgrid(np.linspace(0., 1., nx, dtype=np.float32),
np.linspace(0., 1., ny, dtype=np.float32))
ui = np.zeros((nx, ny), dtype=np.float32)
r = (xx - .5)**2. + (yy - .5)**2.
ui[np.logical_and(.05 <= r, r <= .1)] = 1.

return ui


def run_simulation(save=False, dx=0.01, dy=0.01, a=0.5, timesteps=100):
nx, ny = int(1 / dx), int(1 / dy)
dx2, dy2 = dx**2, dy**2
dt = dx2 * dy2 / (2 * a * (dx2 + dy2))

u = TimeData(
name='u', shape=(nx, ny), time_dim=timesteps,
time_order=1, space_order=2, save=save, pad_time=save
)
u.init_data(-1, initial())

a, h, s = symbols('a h s')
eqn = Eq(u.dt, a * (u.dx2 + u.dy2))
stencil = solve(eqn, u.forward)[0]
op = Operator(stencils=Eq(u.forward, stencil), substitutions={a: 0.5, h: dx, s: dt},
nt=timesteps, shape=(nx, ny), spc_border=1, time_order=1)
op.apply()

if save:
return u.get_data(timesteps - 2)
else:
return u.get_data()


def test_save():
assert(np.array_equal(run_simulation(True), run_simulation()))