Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 13 additions & 38 deletions devito/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ class DenseData(SymbolicData):
:param shape: Shape of the spatial data grid
:param dtype: Data type of the buffered data
:param space_order: Discretisation order for space derivatives
:param initializer: Function to initialize the data, optional

Note: :class:`DenseData` objects are assumed to be constant in time and
therefore do not support time derivatives. Use :class:`TimeData` for
Expand All @@ -119,8 +120,11 @@ def __init__(self, *args, **kwargs):
self.shape = kwargs.get('shape')
self.dtype = kwargs.get('dtype', np.float32)
self.space_order = kwargs.get('space_order', 1)
initializer = kwargs.get('initializer', None)
if initializer is not None:
assert(callable(initializer))
self.initializer = initializer
self._data = kwargs.get('_data', None)
self.initializer = None
MemmapManager.setup(self, *args, **kwargs)
# Store new instance in symbol cache
self._cache_put(self)
Expand Down Expand Up @@ -155,15 +159,6 @@ def indexify(self):

return self.indexed[indices]

def set_initializer(self, lambda_initializer):
"""Set data intialising function to given lambda function.

:param lambda_initializer: Given lambda function.
"""
assert(callable(lambda_initializer))

self.initializer = lambda_initializer

def _allocate_memory(self):
"""Function to allocate memmory in terms of numpy ndarrays.

Expand Down Expand Up @@ -276,6 +271,8 @@ class TimeData(DenseData):
:param dtype: Data type of the buffered data
:param save: Save the intermediate results to the data buffer. Defaults
to `False`, indicating the use of alternating buffers.
:param pad_time: Set to `True` if save is True and you want to initialize
the first :obj:`time_order` timesteps.
:param time_dim: Size of the time dimension that dictates the leading
dimension of the data buffer if :param save: is True.
:param time_order: Order of the time discretization which affects the
Expand Down Expand Up @@ -311,6 +308,12 @@ def __init__(self, *args, **kwargs):
# Store final instance in symbol cache
self._cache_put(self)

def initialize(self):
if self.initializer is not None:
if self._full_data is None:
self._allocate_memory()
self.initializer(self._full_data)

@classmethod
def indices(cls, shape):
"""Return the default dimension indices for a given data shape
Expand All @@ -331,34 +334,6 @@ def _allocate_memory(self):
if self.pad_time:
self._data = self._data[self.time_order:, :, :]

def init_data(self, timestep, data):
"""Function to initialize the initial time steps

:param timestep: Time step to initialize.
Must be negative since calculated timesteps start from 0.
:param data: :class:`numpy.ndarray` containing the initial spatial data
"""
if self._full_data is None:
self._allocate_memory()

assert timestep < 0, "Timestep must be negative"
assert data.shape == self._full_data[0].shape, \
"Data must have the same shape as the spatial data"

# Adds the time_order to the index to access padded indexes
timestep += self.time_order
self._full_data[timestep] = data

def get_data(self, timestep=0):
"""Returns the calculated data at the specified timestep

:param timestep: The timestep from which we want to retrieve the data.
Specify only in the case :obj:`self.save` is True
"""
timestep += self.time_order

return self._full_data[timestep, :]

@property
def dim(self):
"""Returns the spatial dimension of the data object"""
Expand Down
11 changes: 7 additions & 4 deletions tests/test_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,19 @@ def initial(dx=0.01, dy=0.01):
return ui


def initializer(data):
data[0, :] = initial()


def run_simulation(save=False, dx=0.01, dy=0.01, a=0.5, timesteps=100):
nx, ny = int(1 / dx), int(1 / dy)
dx2, dy2 = dx**2, dy**2
dt = dx2 * dy2 / (2 * a * (dx2 + dy2))

u = TimeData(
name='u', shape=(nx, ny), time_dim=timesteps,
name='u', shape=(nx, ny), time_dim=timesteps, initializer=initializer,
time_order=1, space_order=2, save=save, pad_time=save
)
u.init_data(-1, initial())

a, h, s = symbols('a h s')
eqn = Eq(u.dt, a * (u.dx2 + u.dy2))
Expand All @@ -35,9 +38,9 @@ def run_simulation(save=False, dx=0.01, dy=0.01, a=0.5, timesteps=100):
op.apply()

if save:
return u.get_data(timesteps - 2)
return u.data[timesteps - 1, :]
else:
return u.get_data()
return u.data[timesteps % 2, :]


def test_save():
Expand Down