Skip to content

Commit

Permalink
cleaned up SourceTFR + SourceTFR tests...
Browse files Browse the repository at this point in the history
...according to review in PR mne-tools#6543

Signed-off-by: Dirk Gütlin <dirk.guetlin@stud.sbg.ac.at>
  • Loading branch information
DiGyt committed Aug 5, 2019
1 parent 33dc256 commit 56c83fa
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 94 deletions.
120 changes: 73 additions & 47 deletions mne/source_tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import numpy as np

from .filter import resample
from .utils import _check_subject, verbose, _time_mask, _check_option
from .utils import (_check_subject, verbose, _time_mask, _check_option,
_validate_type)
from .io.base import ToDataFrameMixin, TimeMixin
from .externals.h5io import write_hdf5

Expand All @@ -27,13 +28,31 @@ class SourceTFR(ToDataFrameMixin, TimeMixin):
Time point of the first sample in data.
tstep : float
Time step between successive samples in data.
freqs : ndarray, shape (n_freqs,)
The frequencies in Hz.
dims : tuple, default ("dipoles", "freqs", "times")
The dimension names of the data, where each element of the tuple
corresponds to one axis of the data field. Allowed values are:
("dipoles", "freqs", "times"), ("dipoles", "epochs", "freqs",
"times"), ("dipoles", "orientations", "freqs", "times"), ("dipoles",
"orientations", "epochs", "freqs", "times").
method : str | None, default None
Comment on the method used to compute the data, as a combination of
the used method and the compued product (e.g. "morlet-power" or
"stockwell-itc").
subject : str | None
The subject name. While not necessary, it is safer to set the
subject parameter to avoid analysis errors.
%(verbose)s
Attributes
----------
freqs : ndarray, shape (n_freqs,)
The frequencies in Hz.
method : str | None, default None
Comment on the method used to compute the data, as a combination of
the used method and the compued product (e.g. "morlet-power" or
"stockwell-itc").
subject : str | None
The subject name.
times : array, shape (n_times,)
Expand All @@ -43,6 +62,8 @@ class SourceTFR(ToDataFrameMixin, TimeMixin):
be an array if there is only one source space (e.g., for volumes).
data : array of shape (n_dipoles, n_times)
The data in source space.
dims : tuple
The dimension names corresponding to the data.
shape : tuple
The shape of the data. A tuple of int (n_dipoles, n_times).
"""
Expand All @@ -64,20 +85,43 @@ def __init__(self, data, vertices=None, tmin=None, tstep=None, freqs=None,
_check_option("dims", list(dims),
[list(v_dims) for v_dims in valid_dims])
_check_option("method", method, valid_methods)
_validate_type(vertices, (np.ndarray, list), "vertices")

if not (isinstance(vertices, np.ndarray) or
isinstance(vertices, list)):
raise ValueError('Vertices must be a numpy array or a list of '
'arrays')
data, kernel, sens_data, vertices = self._prepare_data(data, vertices,
dims)

self.dims = dims
self.vertices = vertices
self.method = method
self.freqs = freqs
self.verbose = verbose
self.subject = _check_subject(None, subject, False)

# TODO: src_type should rather represent the stc source type
self._src_type = 'SourceTFR'
self._data_ndim = len(dims)
self._data = data
self._kernel = kernel
self._sens_data = sens_data
self._kernel_removed = False
self._tmin = tmin
self._tstep = tstep
self._times = None
self._update_times()

self.dims = dims
self.method = method
self.freqs = freqs

def __repr__(self): # noqa: D105
s = "{} vertices".format((sum(len(v) for v in self._vertices_list),))
if self.subject is not None:
s += ", subject : {}".format(self.subject)
s += ", tmin : {} (ms)".format(1e3 * self.tmin)
s += ", tmax : {} (ms)".format(1e3 * self.times[-1])
s += ", tstep : {} (ms)".format(1e3 * self.tstep)
s += ", data shape : {}".format(self.shape)
return "<{0} | {1}>".format(type(self).__name__, s)

def _prepare_data(self, data, vertices, dims):
"""Prepare the data for the SourceTFR init."""
kernel, sens_data = None, None
if isinstance(data, tuple):
if len(data) != 2:
Expand All @@ -87,9 +131,9 @@ def __init__(self, data, vertices=None, tmin=None, tstep=None, freqs=None,
if kernel.shape[1] != sens_data.shape[0]:
raise ValueError('kernel and sens_data have invalid '
'dimensions')
if sens_data.ndim != self._data_ndim:
raise ValueError('The sensor data must have %s dimensions, got'
' %s' % (self._data_ndim, sens_data.ndim,))
if sens_data.ndim != len(dims):
raise ValueError('The sensor data must have {0} dimensions, '
'got {1}'.format(len(dims), sens_data.ndim, ))
# TODO: Make sure this is supported
if 'orientations' in dims:
raise ValueError('Multiple orientations are not supported for '
Expand All @@ -111,40 +155,21 @@ def __init__(self, data, vertices=None, tmin=None, tstep=None, freqs=None,
# safeguard the user against doing something silly
if data is not None:
if data.shape[0] != n_src:
raise ValueError('Number of vertices (%i) and stfr.shape[0] '
'(%i) must match' % (n_src, data.shape[0]))
if data.ndim != self._data_ndim:
raise ValueError('Number of vertices ({0}) and stfr.shape[0] '
'({1}) must match'.format(n_src,
data.shape[0]))
if data.ndim != len(dims):
raise ValueError('Data (shape {0}) must have {1} dimensions '
'for SourceTFR with dims={2}'
.format(data.shape, self._data_ndim,
self.dims))
.format(data.shape, len(dims),
dims))

if "orientations" in dims and data.shape[1] != 3:
raise ValueError('If multiple orientations are defined, '
'stfr.shape[1] must be 3. Got '
'shape[1] == {}'.format(data.shape[1]))

self._data = data
self._tmin = tmin
self._tstep = tstep
self.vertices = vertices
self.verbose = verbose
self._kernel = kernel
self._sens_data = sens_data
self._kernel_removed = False
self._times = None
self._update_times()
self.subject = _check_subject(None, subject, False)

def __repr__(self): # noqa: D105
s = "%d vertices" % (sum(len(v) for v in self._vertices_list),)
if self.subject is not None:
s += ", subject : %s" % self.subject
s += ", tmin : %s (ms)" % (1e3 * self.tmin)
s += ", tmax : %s (ms)" % (1e3 * self.times[-1])
s += ", tstep : %s (ms)" % (1e3 * self.tstep)
s += ", data shape : %s" % (self.shape,)
return "<%s | %s>" % (type(self).__name__, s)
return data, kernel, sens_data, vertices

@property
def _vertices_list(self):
Expand All @@ -164,11 +189,12 @@ def save(self, fname, ftype='h5', verbose=None):
File format to use. Currently, the only allowed values is "h5".
%(verbose_meth)s
"""

# this message looks more informative to me than _check_option().
if ftype != 'h5':
raise ValueError('%s objects can only be written as HDF5 files.'
% (self.__class__.__name__,))
if not fname.endswith('.h5'):
fname += '-stfr.h5'
raise ValueError('{} objects can only be written as HDF5 files.'
.format(self.__class__.__name__, ))
fname = fname if fname.endswith('h5') else '{}-stfr.h5'.format(fname)
write_hdf5(fname,
dict(vertices=self.vertices, data=self.data, tmin=self.tmin,
tstep=self.tstep, subject=self.subject,
Expand Down Expand Up @@ -272,8 +298,8 @@ def data(self):
def data(self, value):
value = np.asarray(value)
if self._data is not None and value.ndim != self._data.ndim:
raise ValueError('Data array should have %d dimensions.' %
self._data.ndim)
raise ValueError('Data array should have {} dimensions.'
.format(self._data.ndim))

# vertices can be a single number, so cast to ndarray
if isinstance(self.vertices, list):
Expand All @@ -285,8 +311,8 @@ def data(self, value):

if value.shape[0] != n_verts:
raise ValueError('The first dimension of the data array must '
'match the number of vertices (%d != %d)' %
(value.shape[0], n_verts))
'match the number of vertices ({0} != {1})'
.format(value.shape[0], n_verts))

self._data = value
self._update_times()
Expand Down Expand Up @@ -329,8 +355,8 @@ def times(self):

@times.setter
def times(self, value):
raise ValueError('You cannot write to the .times attribute directly. '
'This property automatically updates whenever '
raise RuntimeError('You cannot write to the .times attribute directly.'
' This property automatically updates whenever '
'.tmin, .tstep or .data changes.')

def _update_times(self):
Expand Down
Loading

0 comments on commit 56c83fa

Please sign in to comment.