diff --git a/devito/interfaces.py b/devito/interfaces.py index 81b82823ae..05643c4ddb 100644 --- a/devito/interfaces.py +++ b/devito/interfaces.py @@ -265,6 +265,7 @@ def __init__(self, *args, **kwargs): return else: super(TimeData, self).__init__(*args, **kwargs) + self._full_data = self._data.view() if self._data else None time_dim = kwargs.get('time_dim') self.time_order = kwargs.get('time_order', 1) self.save = kwargs.get('save', False) @@ -295,9 +296,39 @@ def _allocate_memory(self): """function to allocate memmory in terms of numpy ndarrays.""" super(TimeData, self)._allocate_memory() + self._full_data = self._data.view() + if self.pad_time: self._data = self._data[self.time_order:, :, :] + def init_data(self, timestep, data): + """Function to initialize the initial time steps + + :param timestep: Time step to initialize. + Must be negative since calculated timesteps start from 0. + :param data: :class:`numpy.ndarray` containing the initial spatial data + """ + if self._full_data is None: + self._allocate_memory() + + assert timestep < 0, "Timestep must be negative" + assert data.shape == self._full_data[0].shape, \ + "Data must have the same shape as the spatial data" + + # Adds the time_order to the index to access padded indexes + timestep += self.time_order + self._full_data[timestep] = data + + def get_data(self, timestep=0): + """Returns the calculated data at the specified timestep + + :param timestep: The timestep from which we want to retrieve the data. + Specify only in the case :obj:`self.save` is True + """ + timestep += self.time_order + + return self._full_data[timestep, :] + @property def dim(self): """Returns the spatial dimension of the data object""" diff --git a/tests/test_save.py b/tests/test_save.py new file mode 100644 index 0000000000..05cc96db49 --- /dev/null +++ b/tests/test_save.py @@ -0,0 +1,44 @@ +import numpy as np +from sympy import Eq, solve, symbols + +from devito.interfaces import TimeData +from devito.operator import Operator + + +def initial(dx=0.01, dy=0.01): + nx, ny = int(1 / dx), int(1 / dy) + xx, yy = np.meshgrid(np.linspace(0., 1., nx, dtype=np.float32), + np.linspace(0., 1., ny, dtype=np.float32)) + ui = np.zeros((nx, ny), dtype=np.float32) + r = (xx - .5)**2. + (yy - .5)**2. + ui[np.logical_and(.05 <= r, r <= .1)] = 1. + + return ui + + +def run_simulation(save=False, dx=0.01, dy=0.01, a=0.5, timesteps=100): + nx, ny = int(1 / dx), int(1 / dy) + dx2, dy2 = dx**2, dy**2 + dt = dx2 * dy2 / (2 * a * (dx2 + dy2)) + + u = TimeData( + name='u', shape=(nx, ny), time_dim=timesteps, + time_order=1, space_order=2, save=save, pad_time=save + ) + u.init_data(-1, initial()) + + a, h, s = symbols('a h s') + eqn = Eq(u.dt, a * (u.dx2 + u.dy2)) + stencil = solve(eqn, u.forward)[0] + op = Operator(stencils=Eq(u.forward, stencil), substitutions={a: 0.5, h: dx, s: dt}, + nt=timesteps, shape=(nx, ny), spc_border=1, time_order=1) + op.apply() + + if save: + return u.get_data(timesteps - 2) + else: + return u.get_data() + + +def test_save(): + assert(np.array_equal(run_simulation(True), run_simulation()))