diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 58575f16262..b3fc74a2f02 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -15,6 +15,8 @@ Current (0.19.dev0) Changelog ~~~~~~~~~ +- Add :class:`mne.SourceTFR` class, a container for time frequency transformed source level data by `Dirk Gütlin`_ + - Add support for making epochs with duplicated events, by allowing three policies: "error" (default), "drop", or "merge" in :class:`mne.Epochs` by `Stefan Appelhoff`_ - Add :func:`mne.channels.make_dig_montage` to create :class:`mne.channels.DigMontage` objects out of np.arrays by `Joan Massich`_ diff --git a/doc/conf.py b/doc/conf.py index 634e66faaff..fb04dee43f9 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -562,7 +562,7 @@ def reset_warnings(gallery_conf, fname): 'FilterEstimator': 'mne.decoding.FilterEstimator', 'EMS': 'mne.decoding.EMS', 'CSP': 'mne.decoding.CSP', 'Beamformer': 'mne.beamformer.Beamformer', - 'Transform': 'mne.transforms.Transform', + 'Transform': 'mne.transforms.Transform', 'SourceTFR': 'mne.SourceTFR' } numpydoc_xref_ignore = { # words diff --git a/doc/python_reference.rst b/doc/python_reference.rst index 89b40474818..cc05acda7ae 100644 --- a/doc/python_reference.rst +++ b/doc/python_reference.rst @@ -708,6 +708,7 @@ Source Space Data VectorSourceEstimate VolSourceEstimate VolVectorSourceEstimate + SourceTFR SourceMorph compute_source_morph head_to_mni diff --git a/mne/__init__.py b/mne/__init__.py index 43eb9af240c..bb83e806d73 100644 --- a/mne/__init__.py +++ b/mne/__init__.py @@ -72,6 +72,7 @@ add_source_space_distances, morph_source_spaces, get_volume_labels_from_aseg, get_volume_labels_from_src) +from .source_tfr import SourceTFR from .annotations import Annotations, read_annotations, events_from_annotations from .epochs import (BaseEpochs, Epochs, EpochsArray, read_epochs, concatenate_epochs) diff --git a/mne/minimum_norm/inverse.py b/mne/minimum_norm/inverse.py index fe313167360..d38a25ec37c 100644 --- a/mne/minimum_norm/inverse.py +++ b/mne/minimum_norm/inverse.py @@ -1071,12 +1071,19 @@ def apply_inverse_raw(raw, inverse_operator, lambda2, method="dSPM", def _apply_inverse_epochs_gen(epochs, inverse_operator, lambda2, method='dSPM', label=None, nave=1, pick_ori=None, prepared=False, method_params=None, - verbose=None): + delayed=False, verbose=None): """Generate inverse solutions for epochs. 
Used in apply_inverse_epochs.""" _check_option('method', method, INVERSE_METHODS) _check_ori(pick_ori, inverse_operator['source_ori']) _check_ch_names(inverse_operator, epochs.info) + is_free_ori = not (is_fixed_orient(inverse_operator) or + pick_ori == 'normal') + + if delayed and is_free_ori and pick_ori != "vector": + raise ValueError("delayed must be False for free orientations other " + "than pick_ori='vector'.") + # # Set up the inverse according to the parameters # @@ -1095,13 +1102,10 @@ def _apply_inverse_epochs_gen(epochs, inverse_operator, lambda2, method='dSPM', tstep = 1.0 / epochs.info['sfreq'] tmin = epochs.times[0] - is_free_ori = not (is_fixed_orient(inverse_operator) or - pick_ori == 'normal') - if pick_ori == 'vector' and noise_norm is not None: noise_norm = noise_norm.repeat(3, axis=0) - if not is_free_ori and noise_norm is not None: + if not (is_free_ori and pick_ori != 'vector') and noise_norm is not None: # premultiply kernel with noise normalization K *= noise_norm @@ -1116,15 +1120,16 @@ def _apply_inverse_epochs_gen(epochs, inverse_operator, lambda2, method='dSPM', # Compute solution and combine current components (non-linear) sol = np.dot(K, e[sel]) # apply imaging kernel - if pick_ori != 'vector': - logger.info('combining the current components...') - sol = combine_xyz(sol) + if is_free_ori and pick_ori != 'vector': + logger.info('combining the current components...') + sol = combine_xyz(sol) if noise_norm is not None: sol *= noise_norm + else: # Linear inverse: do computation here or delayed - if len(sel) < K.shape[1]: + if delayed: sol = (K, e[sel]) else: sol = np.dot(K, e[sel]) @@ -1143,7 +1148,7 @@ def _apply_inverse_epochs_gen(epochs, inverse_operator, lambda2, method='dSPM', def apply_inverse_epochs(epochs, inverse_operator, lambda2, method="dSPM", label=None, nave=1, pick_ori=None, return_generator=False, prepared=False, - method_params=None, verbose=None): + method_params=None, delayed=False, verbose=None): """Apply inverse operator to Epochs. Parameters @@ -1179,6 +1184,19 @@ def apply_inverse_epochs(epochs, inverse_operator, lambda2, method="dSPM", Additional options for eLORETA. See Notes of :func:`apply_inverse`. .. versionadded:: 0.16 + delayed : bool + If False, the source time courses are computed. If True, they are + stored as a tuple of two smaller arrays in order to save memory. In + this case, the first array in the tuple corresponds to the "kernel" + shape (n_vertices [, n_orientations], n_sensors) and the second array + to the "sens_data" shape (n_sensors, n_times). The full source time + courses field will be automatically computed when stc.data is called + for the first time (see for example: :class:`mne.SourceEstimate`). + `delayed=True` is only implemented for fixed orientations (e.g. + from pick_ori = "normal") as well as pick_ori="vector". + Defaults to False. + + .. 
versionadded:: 0.19 %(verbose)s Returns @@ -1194,7 +1212,7 @@ def apply_inverse_epochs(epochs, inverse_operator, lambda2, method="dSPM", stcs = _apply_inverse_epochs_gen( epochs, inverse_operator, lambda2, method=method, label=label, nave=nave, pick_ori=pick_ori, verbose=verbose, prepared=prepared, - method_params=method_params) + method_params=method_params, delayed=delayed) if not return_generator: # return a list diff --git a/mne/minimum_norm/tests/test_inverse.py b/mne/minimum_norm/tests/test_inverse.py index 452e4e325a4..5e78e28696f 100644 --- a/mne/minimum_norm/tests/test_inverse.py +++ b/mne/minimum_norm/tests/test_inverse.py @@ -783,12 +783,9 @@ def test_apply_mne_inverse_fixed_raw(): assert_array_almost_equal(stc.data, stc3.data) -@testing.requires_testing_data -def test_apply_mne_inverse_epochs(): - """Test MNE with precomputed inverse operator on Epochs.""" - inverse_operator = read_inverse_operator(fname_full) - label_lh = read_label(fname_label % 'Aud-lh') - label_rh = read_label(fname_label % 'Aud-rh') +@pytest.fixture +def epochs(): + """Create an epochs object used for testing.""" event_id, tmin, tmax = 1, -0.2, 0.5 raw = read_raw_fif(fname_raw) @@ -801,6 +798,16 @@ def test_apply_mne_inverse_epochs(): epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks, baseline=(None, 0), reject=reject, flat=flat) + return epochs + + +@testing.requires_testing_data +def test_apply_mne_inverse_epochs(epochs): + """Test MNE with precomputed inverse operator on Epochs.""" + inverse_operator = read_inverse_operator(fname_full) + label_lh = read_label(fname_label % 'Aud-lh') + label_rh = read_label(fname_label % 'Aud-rh') + inverse_operator = prepare_inverse_operator(inverse_operator, nave=1, lambda2=lambda2, method="dSPM") @@ -893,4 +900,37 @@ def test_inverse_ctf_comp(): apply_inverse_raw(raw, inv, 1. / 9.) 
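For orientation (not part of the patch itself): the lazy representation that the new ``delayed`` option keeps around is the same ``(kernel, sens_data)`` mechanism that :class:`mne.SourceEstimate` already accepts, as referenced in the docstring above. A minimal, self-contained sketch with made-up arrays and illustrative variable names:

```python
import numpy as np
from mne import SourceEstimate

# kernel: (n_vertices, n_sensors); sens_data: (n_sensors, n_times)
rng = np.random.RandomState(0)
kernel = rng.rand(20, 5)
sens_data = rng.rand(5, 30)
vertices = [np.arange(10), np.arange(10)]  # 10 left + 10 right hemisphere vertices

# pass the (kernel, sens_data) tuple instead of the full data array
stc = SourceEstimate((kernel, sens_data), vertices, tmin=0., tstep=0.001)

# the full source data is only expanded (kernel @ sens_data) on first access,
# which is exactly what delayed=True defers for each epoch
np.testing.assert_allclose(stc.data, np.dot(kernel, sens_data))
```

The memory saving comes from storing one small kernel plus the sensor data per epoch rather than the full (n_vertices, n_times) array.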
+def _check_delayed_data(inst, delayed): + """Check whether data is represented as kernel or not.""" + if delayed: + assert isinstance(inst._kernel, np.ndarray) + assert isinstance(inst._sens_data, np.ndarray) + assert inst._data is None + assert not inst._kernel_removed + else: + assert inst._kernel is None + assert inst._sens_data is None + assert isinstance(inst._data, np.ndarray) + + +@testing.requires_testing_data +@pytest.mark.parametrize('pick_ori', ['normal', 'vector']) +def test_delayed_data(epochs, pick_ori): + """Test if kernel in apply_inverse_epochs was properly applied.""" + inverse_operator = read_inverse_operator(fname_full) + inverse_operator = prepare_inverse_operator(inverse_operator, nave=1, + lambda2=lambda2, + method="dSPM") + + full_stcs = apply_inverse_epochs(epochs, inverse_operator, lambda2, + pick_ori=pick_ori, delayed=False) + kernel_stcs = apply_inverse_epochs(epochs, inverse_operator, lambda2, + pick_ori=pick_ori, delayed=True) + + for full_stc, kern_stc in zip(full_stcs, kernel_stcs): + _check_delayed_data(full_stc, delayed=False) + _check_delayed_data(kern_stc, delayed=True) + assert_allclose(kern_stc.data, full_stc.data) + + run_tests_if_main() diff --git a/mne/source_estimate.py b/mne/source_estimate.py index 5af0b43409b..735c7e86588 100644 --- a/mne/source_estimate.py +++ b/mne/source_estimate.py @@ -404,6 +404,12 @@ def guess_src_type(): raise ValueError('vertices has to be either a list with one or more ' 'arrays or an array') + if isinstance(data, tuple): + data, sens_data = data[0], data[1] + is_kernel = True + else: + is_kernel = False + # massage the data if src_type == 'surface' and vector: n_vertices = len(vertices[0]) + len(vertices[1]) @@ -416,9 +422,12 @@ def guess_src_type(): else: pass # noqa - return Klass( - data=data, vertices=vertices, tmin=tmin, tstep=tstep, subject=subject - ) + if is_kernel: + return Klass(data=(data, sens_data), vertices=vertices, tmin=tmin, + tstep=tstep, subject=subject) + else: + return Klass(data=data, vertices=vertices, tmin=tmin, tstep=tstep, + subject=subject) def _verify_source_estimate_compat(a, b): @@ -485,7 +494,7 @@ def __init__(self, data, vertices=None, tmin=None, tstep=None, raise ValueError('If data is a tuple it has to be length 2') kernel, sens_data = data data = None - if kernel.shape[1] != sens_data.shape[0]: + if kernel.shape[-1] != sens_data.shape[0]: raise ValueError('kernel and sens_data have invalid ' 'dimensions') if sens_data.ndim != 2: diff --git a/mne/source_tfr.py b/mne/source_tfr.py new file mode 100644 index 00000000000..c46b03fa738 --- /dev/null +++ b/mne/source_tfr.py @@ -0,0 +1,483 @@ +# -*- coding: utf-8 -*- +# +# Authors: Dirk Gütlin +# Joan Massich +# +# License: BSD (3-clause) + +import copy +import numpy as np + +from .filter import resample +from .utils import (_check_subject, verbose, fill_doc, _time_mask, + _freq_mask, _check_option, _validate_type) +from .io.base import ToDataFrameMixin, TimeMixin +from .externals.h5io import write_hdf5 +from .source_estimate import (SourceEstimate, VectorSourceEstimate, + VolSourceEstimate) +from .viz import (plot_source_estimates, plot_vector_source_estimates, + plot_volume_source_estimates) + + +@fill_doc +class SourceTFR(ToDataFrameMixin, TimeMixin): + """Class for time-frequency transformed source level data. + + Parameters + ---------- + data : array | tuple, shape (2,) + Time-frequency transformed data in source space. 
The data can either
+        be a single array of shape (n_dipoles[, n_orientations][, n_epochs],
+        n_freqs, n_times) or a tuple with two arrays: "kernel" shape
+        (n_dipoles, n_sensors) and "sens_data" shape (n_sensors, n_freqs,
+        n_times). In this case, the source space data corresponds to
+        "numpy.dot(kernel, sens_data)".
+    vertices : array | list of array
+        Vertex numbers corresponding to the data.
+    tmin : float
+        Time point of the first sample in data.
+    tstep : float
+        Time step between successive samples in data.
+    freqs : ndarray, shape (n_freqs,)
+        The frequencies in Hz.
+    dims : tuple, default ("dipoles", "freqs", "times")
+        The dimension names of the data, where each element of the tuple
+        corresponds to one axis of the data field. Allowed values are:
+        ("dipoles", "freqs", "times"), ("dipoles", "epochs", "freqs",
+        "times"), ("dipoles", "orientations", "freqs", "times"), ("dipoles",
+        "orientations", "epochs", "freqs", "times").
+    method : str | None, default None
+        Comment on the method used to compute the data, as a combination of
+        the method used and the computed product (e.g. "morlet-power" or
+        "stockwell-itc").
+    subject : str | None
+        The subject name. While not necessary, it is safer to set the
+        subject parameter to avoid analysis errors.
+    src_type : str, default "surface"
+        The source type of the object. Can be "surface" or "volume".
+    %(verbose)s
+
+    Attributes
+    ----------
+    freqs : ndarray, shape (n_freqs,)
+        The frequencies in Hz.
+    method : str | None, default None
+        Comment on the method used to compute the data, as a combination of
+        the method used and the computed product (e.g. "morlet-power" or
+        "stockwell-itc").
+    subject : str | None
+        The subject name.
+    times : array, shape (n_times,)
+        The time vector.
+    data : array, shape (n_dipoles[, n_orientations][, n_epochs], n_freqs, n_times)
+        The time-frequency transformed data in source space.
+    dims : tuple
+        The dimension names corresponding to the data.
+    shape : tuple
+        The shape of the data, matching ``dims`` (e.g.
+        (n_dipoles, n_freqs, n_times)).
+ """ + + @verbose + def __init__(self, data, vertices=None, tmin=None, tstep=None, freqs=None, + dims=("dipoles", "freqs", "times"), method=None, subject=None, + src_type="surface", verbose=None): # noqa: D102 + + valid_dims = [("dipoles", "freqs", "times"), + ("dipoles", "epochs", "freqs", "times"), + ("dipoles", "orientations", "freqs", "times"), + ("dipoles", "orientations", "epochs", "freqs", "times")] + + valid_methods = ["morlet-power", "multitaper-power", "stockwell-power", + "morlet-itc", "multitaper-itc", "stockwell-itc", None] + + # unfortunately, _check option does not work with the original tuples + _check_option("dims", list(dims), + [list(v_dims) for v_dims in valid_dims]) + _check_option("method", method, valid_methods) + _check_option("src_type", src_type, ["surface", "volume"]) + _validate_type(vertices, (np.ndarray, list), "vertices") + + data, kernel, sens_data, vertices = _prepare_data(data, vertices, dims) + + self.dims = dims + self.method = method + self.freqs = freqs + self.verbose = verbose + self.subject = _check_subject(None, subject, False) + + # TODO: src_type should rather represent the stc source type + self._src_type = src_type + self._data_ndim = len(dims) + self._vertices = vertices + self._data = data + self._kernel = kernel + self._sens_data = sens_data + self._kernel_removed = False + self._tmin = tmin + self._tstep = tstep + self._times = None + self._update_times() + + def __repr__(self): # noqa: D105 + s = "{} vertices".format((sum(len(v) for v in self.vertices),)) + if self.subject is not None: + s += ", subject : {}".format(self.subject) + s += ", tmin : {} (ms)".format(1e3 * self.tmin) + s += ", tmax : {} (ms)".format(1e3 * self.times[-1]) + s += ", tstep : {} (ms)".format(1e3 * self.tstep) + s += ", data shape : {}".format(self.shape) + return "<{0} | {1}>".format(type(self).__name__, s) + + @property + def vertices(self): + """ The indices of the dipoles in the different source spaces. Can + be an array if there is only one source space (e.g., for volumes). + """ + verts = self._vertices + return [verts] if self._src_type == "volume" else verts + + # TODO: also support loading data + @verbose + def save(self, fname, ftype='h5', verbose=None): + """Save the full SourceTFR to an HDF5 file. + + Parameters + ---------- + fname : string + The file name to write the SourceTFR to, should end in + '-stfr.h5'. + ftype : string + File format to use. Currently, the only allowed values is "h5". + %(verbose_meth)s + """ + # this message looks more informative to me than _check_option(). + if ftype != 'h5': + raise ValueError('{} objects can only be written as HDF5 files.' + .format(self.__class__.__name__, )) + fname = fname if fname.endswith('h5') else '{}-stfr.h5'.format(fname) + write_hdf5(fname, + dict(vertices=self._vertices, data=self.data, + tmin=self.tmin, tstep=self.tstep, + subject=self.subject, src_type=self._src_type), + title='mnepython', overwrite=True) + + @property + def sfreq(self): + """Sample rate of the data.""" + return 1. / self.tstep + + def _remove_kernel_sens_data_(self): + """Remove kernel and sensor space data and compute self._data.""" + if self._kernel is not None or self._sens_data is not None: + self._kernel_removed = True + self._data = np.tensordot(self._kernel, self._sens_data, + axes=([-1], [0])) + self._kernel = None + self._sens_data = None + + def plot(self, fmin=None, fmax=None, epoch=0, **plot_params): + """Plot SourceTFR. 
+ + Plots selected frequencies, using mne.viz.plot_source_estimates, + mne.viz.plot_vector_source_estimates, or + mne.viz.plot_volume_source_estimates, depending on the SourceEstimate + type, from which the SourceTFR was created. + All included frequencies from fmin to fmax will be averaged into one + plot. + + Parameters + ---------- + fmin : float | None + The lowest frequency to include. If None, the lowest frequency + in stfr.freqs is used. + fmax : float | None + The highest frequency to include. If None, the highest frequency + in stfr.freqs is used. + epoch : int, default 0 + If the stfr object contains an "epochs" dimension, only the epoch + index defined in epoch will be plotted. Else will be ignored. + **plot_params : + Additional parameters passed to the respective plotting function. + + Returns + ------- + figure : instance of surfer.Brain | matplotlib.figure.Figure + An instance of :class:`surfer.Brain` from PySurfer or + matplotlib figure. + + See Also + -------- + mne.viz.plot_source_estimates + mne.viz.plot_vector_source_estimates + mne.viz.plot_volume_source_estimates + """ + freq_idx = _freq_mask(self.freqs, self.sfreq, fmin, fmax) + # FIXME: sum over average? sum is easier to interprete + # but will result in bad color scalings + data_cropped = np.mean(self.data[..., freq_idx, :], axis=-2) + if "epochs" in self.dims: + data_cropped = data_cropped[..., epoch, :] + + if self._src_type == "volume": + # use the magnitude only if it's a VolVectorSourceEstimate + # (see _BaseVectorSourceEstimate.plot) + if "orientations" in self.dims: + data_cropped = np.linalg.norm(data_cropped, axis=1) + brain = plot_volume_source_estimates( + VolSourceEstimate(data_cropped, self.vertices, self.tmin, + self.tstep, self.subject), **plot_params) + elif "orientations" in self.dims: + brain = plot_vector_source_estimates( + VectorSourceEstimate(data_cropped, self.vertices, self.tmin, + self.tstep, self.subject), **plot_params) + else: + brain = plot_source_estimates( + SourceEstimate(data_cropped, self.vertices, self.tmin, + self.tstep, self.subject), **plot_params) + + return brain + + def crop(self, tmin=None, tmax=None): + """Restrict SourceTFR to a time interval. + + Parameters + ---------- + tmin : float | None + The first time point in seconds. If None the first present is used. + tmax : float | None + The last time point in seconds. If None the last present is used. + """ + mask = _time_mask(self.times, tmin, tmax, sfreq=self.sfreq) + self.tmin = self.times[np.where(mask)[0][0]] + if self._kernel is not None and self._sens_data is not None: + self._sens_data = self._sens_data[..., mask] + else: + self.data = self.data[..., mask] + + return self # return self for chaining methods + + @verbose + def resample(self, sfreq, npad='auto', window='boxcar', n_jobs=1, + verbose=None): + """Resample data. + + Parameters + ---------- + sfreq : float + New sample rate to use. + npad : int | str + Amount to pad the start and end of the data. + Can also be "auto" to use a padding that will result in + a power-of-two size (can be much faster). + window : string or tuple + Window to use in resampling. See scipy.signal.resample. + %(n_jobs)s + %(verbose_meth)s + + Notes + ----- + For some data, it may be more accurate to use npad=0 to reduce + artifacts. This is dataset dependent -- check your data! + + Note that the sample rate of the original data is inferred from tstep. 
+ """ + # resampling in sensor instead of source space gives a somewhat + # different result, so we don't allow it + self._remove_kernel_sens_data_() + + o_sfreq = 1.0 / self.tstep + self.data = resample(self.data, sfreq, o_sfreq, npad, n_jobs=n_jobs) + + # adjust indirectly affected variables + self.tstep = 1.0 / sfreq + return self + + @property + def data(self): + """Create the SourceTFR data field. + + Parameters + ---------- + %(verbose_meth)s + + Returns + ------- + data : array + The source level time-frequency transformed data. + """ + if self._data is None: + # compute the solution the first time the data is accessed and + # remove the kernel and sensor data + self._remove_kernel_sens_data_() + + return self._data + + @data.setter + def data(self, value): + value = np.asarray(value) + if self._data is not None and value.ndim != self._data.ndim: + raise ValueError('Data array should have {} dimensions.' + .format(self._data.ndim)) + + # vertices can be a single number, so cast to ndarray + if isinstance(self._vertices, list): + n_verts = sum([len(v) for v in self._vertices]) + elif isinstance(self._vertices, np.ndarray): + n_verts = len(self._vertices) + else: + raise ValueError('Vertices must be a list or numpy array') + + if value.shape[0] != n_verts: + raise ValueError('The first dimension of the data array must ' + 'match the number of vertices ({0} != {1})' + .format(value.shape[0], n_verts)) + + self._data = value + self._update_times() + + @property + def shape(self): + """Shape of the data.""" + if self._data is None: + return (self._kernel.shape[0],) + self._sens_data.shape[1:] + + else: + return self._data.shape + + @property + def tmin(self): + """The first timestamp.""" + return self._tmin + + @tmin.setter + def tmin(self, value): + self._tmin = float(value) + self._update_times() + + @property + def tstep(self): + """The change in time between two consecutive samples (1 / sfreq).""" + return self._tstep + + @tstep.setter + def tstep(self, value): + if value <= 0: + raise ValueError('.tstep must be greater than 0.') + self._tstep = float(value) + self._update_times() + + @property + def times(self): + """A timestamp for each sample.""" + return self._times + + @times.setter + def times(self, value): + raise RuntimeError('You cannot write to the .times attribute directly.' + ' This property automatically updates whenever ' + '.tmin, .tstep or .data changes.') + + def _update_times(self): + """Update the times attribute after changing tmin, tmax, or tstep.""" + self._times = self.tmin + (self.tstep * np.arange(self.shape[-1])) + self._times.flags.writeable = False + + def copy(self): + """Return copy of SourceTFR instance.""" + return copy.deepcopy(self) + + +def _prepare_data(data, vertices, dims): + """Check SourceTFR data and construct according data fields. + + Parameters + ---------- + data : array | tuple, shape (2,) + Time-frequency transformed data in source space. Can be either + the full data array, or a tuple of two arrays (kernel, sens_data), + where the full data is equal to np.dot(kernel, sens_data). + vertices : array | list of array + Vertex numbers corresponding to the data. + dims : tuple, default ("dipoles", "freqs", "times") + The dimension names of the data, where each element of the tuple + corresponds to one axis of the data field. Allowed values are: + ("dipoles", "freqs", "times"), ("dipoles", "epochs", "freqs", + "times"), ("dipoles", "orientations", "freqs", "times"), ("dipoles", + "orientations", "epochs", "freqs", "times"). 
+ + Returns + ------- + data : array | None + The source level time-frequency transformed data. If an array was + passed as the data argument, data will be an array. Else None. + kernel : array | None + The imaging kernel to construct the source level time-frequency + transformed data. If a tuple was passed as the data argument, kernel + will be an array. Else None. + sens_data : array | None + The sensor level data to construct the source level time-frequency + transformed data. If a tuple was passed as the data argument + sens_data will be an array. Else None. + vertices : array | None + Flattened vertex numbers corresponding to the data. + """ + kernel, sens_data = None, None + if isinstance(data, tuple): + if len(data) != 2: + raise ValueError('If data is a tuple it has to be length 2') + kernel, sens_data = data + data = None + if kernel.shape[-1] != sens_data.shape[0]: + raise ValueError('The last kernel dimension and the first data ' + 'dimension must be of equal size. Got {0} and ' + '{1} instead.' + .format(kernel.shape[-1], sens_data.shape[0])) + if sens_data.ndim != len(dims): + raise ValueError('The sensor data must have {0} dimensions, ' + 'got {1}'.format(len(dims), sens_data.ndim, )) + # TODO: Make sure this is supported + if 'orientations' in dims: + raise NotImplementedError('Multiple orientations are not ' + 'supported for data=(kernel, sens_data)') + + vertices, n_src = _prepare_vertices(vertices) + + # safeguard the user against doing something silly + if data is not None: + _check_data_shape(data, dims, n_src) + + return data, kernel, sens_data, vertices + + +def _prepare_vertices(vertices): + """Check the vertices and return flattened vertices and their length.""" + if isinstance(vertices, list): + vertices = [np.asarray(v, int) for v in vertices] + if any(np.any(np.diff(v.astype(int)) <= 0) for v in vertices): + raise ValueError('Vertices must be ordered in increasing ' + 'order.') + n_src = sum([len(v) for v in vertices]) + if len(vertices) == 1: + vertices = vertices[0] + + elif isinstance(vertices, np.ndarray): + n_src = len(vertices) + return vertices, n_src + + +def _check_data_shape(data, dims, n_src): + """Check SourceTFR data according to data and vertices dimensions.""" + if data.shape[0] != n_src: + raise ValueError('Number of vertices ({0}) and stfr.shape[0] ' + '({1}) must match'.format(n_src, + data.shape[0])) + if data.ndim != len(dims): + raise ValueError('Data (shape {0}) must have {1} dimensions ' + 'for SourceTFR with dims={2}' + .format(data.shape, len(dims), + dims)) + + if "orientations" in dims and data.shape[1] != 3: + raise ValueError('If multiple orientations are defined, ' + 'stfr.shape[1] must be 3. 
Got ' + 'shape[1] == {}'.format(data.shape[1])) diff --git a/mne/tests/test_source_tfr.py b/mne/tests/test_source_tfr.py new file mode 100644 index 00000000000..6e03b5e4072 --- /dev/null +++ b/mne/tests/test_source_tfr.py @@ -0,0 +1,237 @@ +# -*- coding: utf-8 -*- +# +# Authors: Dirk Gütlin +# Joan Massich +# +# License: BSD (3-clause) + +from copy import deepcopy +import os.path as op + +import numpy as np +from numpy.testing import (assert_array_equal, + assert_allclose, assert_equal) +import pytest +from mne.utils import _TempDir, requires_h5py, run_tests_if_main +from mne.source_tfr import SourceTFR + +rnd = np.random.RandomState(23) + + +@pytest.fixture(scope="module") +def fake_stfr(): + """Create a fake SourceTFR object for testing.""" + verts = [np.arange(10), np.arange(90)] + return SourceTFR(rnd.rand(100, 20, 10), verts, 0, 1e-1, 'foo') + + +@pytest.fixture(scope="module") +def fake_kernel_stfr(): + """Create a fake kernel SourceTFR object for testing.""" + kernel = rnd.rand(100, 40) + sens_data = rnd.rand(40, 20, 10) + verts = [np.arange(10), np.arange(90)] + return SourceTFR((kernel, sens_data), verts, 0, 1e-1, 'foo') + + +def test_stfr_kernel_equality(fake_stfr, fake_kernel_stfr): + """Test if kernelized SourceTFR produce correct data.""" + # compare kernelized and normal data + kernel = rnd.rand(100, 40) + sens_data = rnd.rand(40, 10, 30) + verts = [np.arange(10), np.arange(90)] + data = np.tensordot(kernel, sens_data, axes=([-1], [0])) + tmin = 0 + tstep = 1e-3 + + kernel_stfr = SourceTFR((kernel, sens_data), verts, tmin, tstep) + full_stfr = SourceTFR(data, verts, tmin, tstep) + + # check if data is in correct shape + assert kernel_stfr.shape == (100, 10, 30) + assert full_stfr.shape == (100, 10, 30) + assert kernel_stfr.data.shape == (100, 10, 30) + assert full_stfr.data.shape == (100, 10, 30) + assert_allclose(kernel_stfr.data, full_stfr.data) + + stfr = fake_stfr + kernel_stfr = fake_kernel_stfr + + # alternatively with the fake data + assert_equal(stfr.shape, kernel_stfr.shape) + assert_array_equal(stfr.data.shape, kernel_stfr.data.shape) + + +def test_stfr_attributes(fake_stfr): + """Test stfr attributes.""" + stfr = fake_stfr.copy() + + n_times = len(stfr.times) + assert_equal(stfr._data.shape[-1], n_times) + assert_array_equal(stfr.times, stfr.tmin + np.arange(n_times) * stfr.tstep) + + assert_allclose(stfr.times, + [0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]) + + def attempt_times_mutation(stfr): + stfr.times -= 1 + + def attempt_assignment(stfr, attr, val): + setattr(stfr, attr, val) + + # .times is read-only + with pytest.raises(RuntimeError, + match="cannot write to the .times attribute directly"): + attempt_times_mutation(stfr) + with pytest.raises(RuntimeError, + match="cannot write to the .times attribute directly"): + attempt_assignment(stfr, "times", [1]) + + # Changing .tmin or .tstep re-computes .times + stfr.tmin = 1 + assert type(stfr.tmin) == float + assert_allclose(stfr.times, + [1., 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9]) + + stfr.tstep = 1 + assert (type(stfr.tstep) == float) + assert_allclose(stfr.times, + [1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) + + # tstep <= 0 is not allowed + with pytest.raises(ValueError, match="must be greater than 0"): + attempt_assignment(stfr, 'tstep', 0) + with pytest.raises(ValueError, match="must be greater than 0"): + attempt_assignment(stfr, 'tstep', -1) + + # Changing .data re-computes .times + stfr.data = rnd.rand(100, 20, 5) + assert_allclose(stfr.times, [1., 2., 3., 4., 5.]) + + # .data must match the 
number of vertices + with pytest.raises(ValueError, match="must match the number of vertices"): + attempt_assignment(stfr, "data", [[[1]]]) + + # .data much match number of dimensions + assign = [None, np.arange(100), [np.arange(100)], [[[np.arange(100)]]]] + for val in assign: + with pytest.raises(ValueError, + match="Data.*? should have.*? dimensions"): + attempt_assignment(stfr, "data", val) + + # .shape attribute must also work when ._data is None + stfr._kernel = np.zeros((2, 2)) + stfr._sens_data = np.zeros((2, 5, 3)) + stfr._data = None + assert_equal(stfr.shape, (2, 5, 3)) + + # bad size of data + stfr = fake_stfr + data = stfr.data[:, :, np.newaxis, :] + with pytest.raises(ValueError, match='3 dimensions for SourceTFR'): + SourceTFR(data, stfr.vertices) + stfr = SourceTFR(data[:, :, :, 0], stfr.vertices, 0, 1) + assert stfr.data.shape == (data.shape[0], data.shape[1], 1) + + +@requires_h5py +def test_io_stfr_h5(fake_stfr, fake_kernel_stfr): + """Test IO for stfr files using HDF5.""" + for stfr in [fake_stfr, fake_kernel_stfr]: + tempdir = _TempDir() + with pytest.raises(ValueError, match="can only be written as HDF5"): + stfr.save(op.join(tempdir, 'tmp'), ftype='foo') + out_name = op.join(tempdir, 'tempfile') + stfr.save(out_name, ftype='h5') + stfr.save(out_name, ftype='h5') # test overwrite + # TODO: no read_source_tfr yet + + +def test_stfr_resample(fake_stfr, fake_kernel_stfr): + """Test sftr.resample().""" + stfr_ = fake_stfr + kernel_stfr_ = fake_kernel_stfr + + for stfr in [stfr_, kernel_stfr_]: + stfr_new = deepcopy(stfr) + o_sfreq = 1.0 / stfr.tstep + # note that using no padding for this stfr reduces edge ringing... + stfr_new.resample(2 * o_sfreq, npad=0) + assert stfr_new.data.shape[-1] == 2 * stfr.data.shape[-1] + assert stfr_new.tstep == stfr.tstep / 2 + stfr_new.resample(o_sfreq, npad=0) + assert stfr_new.data.shape[-1] == stfr.data.shape[-1] + assert stfr_new.tstep == stfr.tstep + assert_allclose(stfr_new.data, stfr.data, 5) + + +def test_stfr_crop(fake_stfr, fake_kernel_stfr): + """Test cropping of SourceTFR data.""" + stfr = fake_stfr + kernel_stfr = fake_kernel_stfr + + for inst in [stfr, kernel_stfr]: + copy_1 = inst.copy() + assert_allclose(copy_1.crop(tmax=0.8).data, inst.data[:, :, :9]) + # FIXME: cropping like this does not work for kernelized stfr/stc + # assert_allclose(copy_1.times, inst.times[:9]) + + copy_2 = inst.copy() + assert_allclose(copy_2.crop(tmin=0.2).data, inst.data[:, :, 2:]) + assert_allclose(copy_2.times, inst.times[2:]) + + +def test_invalid_params(): + """Test invalid SourceTFR parameters.""" + data = rnd.rand(40, 10, 20) + verts = [np.arange(10), np.arange(30)] + tmin = 0 + tstep = 1e-3 + + with pytest.raises(TypeError, match="vertices must be an instance of " + "ndarray or list"): + SourceTFR(data, {"1": 1, "2": 2}, tmin, tstep) + + with pytest.raises(ValueError, + match='data.*? tuple .*? has to be length 2'): + SourceTFR((data, (42, 42), (42, 42)), verts, tmin, tstep) + + with pytest.raises(ValueError, match='last kernel.*? first data dimension' + ' must be of equal size'): + SourceTFR((np.zeros((42, 42)), data), verts, tmin, tstep) + + with pytest.raises(ValueError, + match='sensor data must have .*? dimensions'): + SourceTFR((np.zeros((2, 20)), np.zeros((20, 3))), verts, tmin, tstep) + + with pytest.raises(ValueError, + match='Vertices must be ordered in increasing order.'): + SourceTFR(data, [np.zeros(10), np.zeros(90)], tmin, tstep) + + with pytest.raises(ValueError, + match='vertices .*? and stfr.shape.*? 
must match'): + SourceTFR(np.ones([42, 10, 20]), verts, tmin, tstep) + + with pytest.raises(ValueError, + match='(shape .*?) must have .*? dimensions'): + SourceTFR(np.ones([40, 10, 20, 10]), verts, tmin, tstep) + + with pytest.raises(ValueError, + match='multiple orientations.*? must be 3'): + SourceTFR(np.ones([40, 10, 20, 10]), verts, tmin, tstep, + dims=("dipoles", "orientations", "freqs", "times")) + + with pytest.raises(ValueError, + match="Invalid value for the 'dims' parameter"): + SourceTFR(data, verts, tmin, tstep, dims=("dipoles", "nonsense")) + + with pytest.raises(ValueError, + match="Invalid value for the 'method' parameter"): + SourceTFR(data, verts, tmin, tstep, method="invalid") + + with pytest.raises(ValueError, + match="Invalid value for the 'src_type' parameter"): + SourceTFR(data, verts, tmin, tstep, src_type="invalid") + + +run_tests_if_main() diff --git a/mne/time_frequency/_stockwell.py b/mne/time_frequency/_stockwell.py index 660272bdbf0..c59a9492050 100644 --- a/mne/time_frequency/_stockwell.py +++ b/mne/time_frequency/_stockwell.py @@ -3,7 +3,7 @@ # # License : BSD 3-clause -from copy import deepcopy +from inspect import isgenerator import math import numpy as np from scipy import fftpack @@ -12,7 +12,7 @@ from ..io.pick import _pick_data_channels, pick_info from ..utils import verbose, warn, fill_doc from ..parallel import parallel_func, check_n_jobs -from .tfr import AverageTFR, _get_data +from .tfr import _get_data, _check_stfr_list_elem, _assign_tfr_class def _check_input_st(x_in, n_fft): @@ -72,6 +72,19 @@ def _st(x, start_f, windows): return ST +def _select_st_freqs(fmin, fmax, sfreq, n_fft): + """Select stockwell freqs based on input freqs and window length.""" + freqs = fftpack.fftfreq(n_fft, 1. / sfreq) + if fmin is None: + fmin = freqs[freqs > 0][0] + if fmax is None: + fmax = freqs.max() + start_f = np.abs(freqs - fmin).argmin() + stop_f = np.abs(freqs - fmax).argmin() + freqs = freqs[start_f:stop_f] + return freqs, start_f, stop_f + + def _st_power_itc(x, start_f, compute_itc, zero_pad, decim, W): """Aux function.""" n_samp = x.shape[-1] @@ -95,9 +108,88 @@ def _st_power_itc(x, start_f, compute_itc, zero_pad, decim, W): itc[i_f] = np.abs(np.mean(TFR, axis=0)) TFR_abs *= TFR_abs psd[i_f] = np.mean(TFR_abs, axis=0) + return psd, itc +def _tfr_list_stockwell(inst, fmin, fmax, n_fft, width, decim, return_itc, + n_jobs): + """Perform stockwell transform on stc lists/generator objects.""" + from ..source_estimate import _BaseSourceEstimate + + for ep_idx, obj in enumerate(inst): + + if not isinstance(obj, _BaseSourceEstimate): + raise TypeError("List or generator input must consist of " + "SourceEstimate objects. Got {}." + .format(type(inst))) + + # load the data. 
Set return_itc=False to omit an Error + data, kernel = _get_data(obj, return_itc=False, fill_dims=False) + + data, n_fft_, zero_pad = _check_input_st(data, n_fft) + + if ep_idx == 0: + # initiate stuff for the first input + sfreq = obj.sfreq + type_ref = type(obj) + tmin_ref = obj._tmin + + n_samp = data.shape[-1] + n_out = (n_samp - zero_pad) + n_out = n_out // decim + bool(n_out % decim) + n_dipoles = len(kernel) if kernel is not None else len(data) + + freqs, start_f, stop_f = _select_st_freqs(fmin, fmax, sfreq, + n_fft_) + W = _precompute_st_windows(n_samp, start_f, stop_f, sfreq, width) + + psd = np.zeros((n_dipoles, len(W), n_out)) + itc = np.zeros_like(psd, dtype=np.complex) if return_itc else None + + else: + # make sure all elements got the same properties as the first one + _check_stfr_list_elem(obj, type_ref, sfreq, tmin_ref) + + X = fftpack.fft(data) + XX = np.concatenate([X, X], axis=-1) + + for i_f, window in enumerate(W): + f = start_f + i_f + ST = fftpack.ifft(XX[:, f:f + n_samp] * window) + if zero_pad > 0: + TFR = ST[:, :-zero_pad:decim] + else: + TFR = ST[:, ::decim] + + if kernel is not None: + # get the full source time series from kernel and tfr + TFR = np.tensordot(kernel, TFR, [-1, 0]) + + # transform complex values + TFR_abs = np.abs(TFR) + TFR_abs[TFR_abs == 0] = 1. + if return_itc: + TFR /= TFR_abs + itc[:, i_f, :] += TFR + TFR_abs *= TFR_abs + psd[:, i_f, :] += TFR_abs + + # divide summed epochs to get the average + psd /= ep_idx + 1 + + if return_itc: + # average the epochs + itc /= ep_idx + 1 + # calculate the abs for each taper + for i_f, window in enumerate(W): + itc[:, i_f, :] = np.abs(itc[:, i_f, :]) + itc = itc.real + + # one list object is passed for type references etc. + return psd, itc, freqs, obj + + @fill_doc def tfr_array_stockwell(data, sfreq, fmin=None, fmax=None, n_fft=None, width=1.0, decim=1, return_itc=False, n_jobs=1): @@ -172,15 +264,7 @@ def tfr_array_stockwell(data, sfreq, fmin=None, fmax=None, n_fft=None, n_out = data.shape[2] // decim + bool(data.shape[2] % decim) data, n_fft_, zero_pad = _check_input_st(data, n_fft) - freqs = fftpack.fftfreq(n_fft_, 1. / sfreq) - if fmin is None: - fmin = freqs[freqs > 0][0] - if fmax is None: - fmax = freqs.max() - - start_f = np.abs(freqs - fmin).argmin() - stop_f = np.abs(freqs - fmax).argmin() - freqs = freqs[start_f:stop_f] + freqs, start_f, stop_f = _select_st_freqs(fmin, fmax, sfreq, n_fft_) W = _precompute_st_windows(data.shape[-1], start_f, stop_f, sfreq, width) n_freq = stop_f - start_f @@ -207,8 +291,9 @@ def tfr_stockwell(inst, fmin=None, fmax=None, n_fft=None, Parameters ---------- - inst : Epochs | Evoked - The epochs or evoked object. + inst : Epochs | Evoked | SourceEstimate | list of SourceEstimate + The object to be computed. Can be Epochs, Evoked, any type of + SourceEstimate, or a list of multiple SourceEstimates of the same type. fmin : None, float The minimum frequency to include. If None defaults to the minimum fft frequency greater than zero. @@ -231,9 +316,9 @@ def tfr_stockwell(inst, fmin=None, fmax=None, n_fft=None, Returns ------- - power : AverageTFR + power : AverageTFR | SourceTFR The averaged power. - itc : AverageTFR + itc : AverageTFR | SourceTFR The intertrial coherence. Only returned if return_itc is True. See Also @@ -248,21 +333,40 @@ def tfr_stockwell(inst, fmin=None, fmax=None, n_fft=None, ----- .. 
versionadded:: 0.9.0 """ + from ..source_estimate import _BaseSourceEstimate # verbose dec is used b/c subfunctions are verbose - data = _get_data(inst, return_itc) - picks = _pick_data_channels(inst.info) - info = pick_info(inst.info, picks) - data = data[:, picks, :] + n_jobs = check_n_jobs(n_jobs) - power, itc, freqs = tfr_array_stockwell(data, sfreq=info['sfreq'], - fmin=fmin, fmax=fmax, n_fft=n_fft, - width=width, decim=decim, - return_itc=return_itc, - n_jobs=n_jobs) + + info = None + nave = None + if isinstance(inst, list) or isgenerator(inst): + + power, itc, freqs, inst = \ + _tfr_list_stockwell(inst, fmin, fmax, n_fft, width, decim, + return_itc, n_jobs) + + else: + data, _ = _get_data(inst, return_itc) + if isinstance(inst, _BaseSourceEstimate): + sfreq = inst.sfreq + else: + nave = len(data) + picks = _pick_data_channels(inst.info) + data = data[:, picks, :] + info = pick_info(inst.info, picks) + sfreq = info['sfreq'] + + power, itc, freqs = \ + tfr_array_stockwell(data, sfreq=sfreq, fmin=fmin, fmax=fmax, + n_fft=n_fft, width=width, decim=decim, + return_itc=return_itc, n_jobs=n_jobs) + times = inst.times[::decim].copy() - nave = len(data) - out = AverageTFR(info, power, times, freqs, nave, method='stockwell-power') + out = _assign_tfr_class(power, inst, info, freqs, times, average=True, + nave=nave, method='stockwell-power') if return_itc: - out = (out, AverageTFR(deepcopy(info), itc, times.copy(), - freqs.copy(), nave, method='stockwell-itc')) + out = (out, _assign_tfr_class(itc, inst, info, freqs, times, + average=True, nave=nave, + method='stockwell-itc')) return out diff --git a/mne/time_frequency/tests/test_stockwell.py b/mne/time_frequency/tests/test_stockwell.py index d672cfd81ed..70dd19b514e 100644 --- a/mne/time_frequency/tests/test_stockwell.py +++ b/mne/time_frequency/tests/test_stockwell.py @@ -20,6 +20,8 @@ _st_power_itc) from mne.time_frequency.tfr import AverageTFR +from mne.time_frequency.tests.test_tfr import (_create_ref_data, + _check_stc_list_input) from mne.utils import run_tests_if_main base_dir = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data') @@ -132,5 +134,43 @@ def test_stockwell_api(): assert (np.log(power.data.max()) * 20 <= 0.0) assert (np.log(power.data.max()) * 20 <= 0.0) + # test list input + _check_stc_list_input(tfr_stockwell) + + +@pytest.mark.filterwarnings('ignore:.*The unit .*? 
has changed from NA to V.') +@pytest.mark.filterwarnings('ignore:.*Applying zero padding.') +@pytest.mark.parametrize('return_itc', [True, False]) +@pytest.mark.parametrize('kernel', [True, False]) +def test_stfr_stockwell(return_itc, kernel): + """Test if SourceTFRs are computed in the same way as sensor space TFRs.""" + fmin = 10 + fmax = 16 + + epochs_ref, stc_list, stc_gen, evoked_ref, stc_single =\ + _create_ref_data(kernel) + + ep_tfr = tfr_stockwell(epochs_ref, fmin, fmax, return_itc=return_itc) + list_stfrs = tfr_stockwell(stc_list, fmin, fmax, return_itc=return_itc) + gen_stfrs = tfr_stockwell(stc_gen, fmin, fmax, return_itc=return_itc) + + if not return_itc: + # make sure we can loop over variables for both return_itc options + ep_tfr, list_stfrs, gen_stfrs = [ep_tfr], [list_stfrs], [gen_stfrs] + + # compare power as well as itc data + for epoch_tfr, list_stfr, gen_stfr in zip(ep_tfr, list_stfrs, gen_stfrs): + assert_allclose(list_stfr.data, epoch_tfr.data) + assert_allclose(gen_stfr.data, epoch_tfr.data) + + assert list_stfr.method == epoch_tfr.method + assert gen_stfr.method == epoch_tfr.method + + evoked_tfr = tfr_stockwell(evoked_ref, fmin, fmax, return_itc=False) + single_stfr = tfr_stockwell(stc_single, fmin, fmax, return_itc=False) + + assert_allclose(evoked_tfr.data, single_stfr.data) + assert evoked_tfr.method == single_stfr.method + run_tests_if_main() diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index 2a47e7d5a31..7c9d538f7c4 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -3,12 +3,13 @@ import numpy as np from numpy.testing import (assert_array_almost_equal, assert_array_equal, - assert_equal) + assert_equal, assert_allclose) import pytest import matplotlib.pyplot as plt import mne -from mne import Epochs, read_events, pick_types, create_info, EpochsArray +from mne import (Epochs, read_events, pick_types, find_events, create_info, + EpochsArray, SourceEstimate, VolSourceEstimate) from mne.io import read_raw_fif from mne.utils import (_TempDir, run_tests_if_main, requires_h5py, requires_pandas, grand_average) @@ -19,12 +20,26 @@ from mne.time_frequency import tfr_array_multitaper, tfr_array_morlet from mne.viz.utils import _fake_click from mne.tests.test_epochs import assert_metadata_equal +from mne.datasets import testing +from mne.label import read_label +from mne.minimum_norm.inverse import (read_inverse_operator, + apply_inverse_epochs) +from mne.minimum_norm.time_frequency import source_induced_power + data_path = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data') raw_fname = op.join(data_path, 'test_raw.fif') event_fname = op.join(data_path, 'test-eve.fif') raw_ctf_fname = op.join(data_path, 'test_ctf_raw.fif') +testing_path = testing.data_path(download=False) +stc_inv_fname = op.join(testing_path, 'MEG', 'sample', + 'sample_audvis_trunc-meg-eeg-oct-6-meg-inv.fif') +stc_raw_fname = op.join(testing_path, 'MEG', 'sample', + 'sample_audvis_trunc_raw.fif') +stc_label_fname = op.join(testing_path, 'MEG', 'sample', + 'labels', 'Aud-lh.label') + def test_tfr_ctf(): """Test that TFRs can be calculated on CTF data.""" @@ -254,6 +269,9 @@ def test_time_frequency(): psd = cwt(data[0], [Ws[0][:-1]], use_fft=False, mode='full') assert_equal(psd.shape, (2, 1, 420)) + # check errors for contradicting SourceEstimate input + _check_stc_list_input(tfr_morlet, freqs=[10, 12], n_cycles=1) + def test_dpsswavelet(): """Test DPSS tapers.""" @@ -753,4 +771,176 @@ def 
test_getitem_epochsTFR(): # Test that current state is maintained assert_array_equal(power.next(), power.data[ind + 1]) + +def _prepare_epochs(n_epochs): + """Load data and create an Epochs object from it.""" + raw = read_raw_fif(stc_raw_fname) + tmin, tmax = -0.2, 0.5 + events = find_events(raw, stim_channel='STI 014') + sel_events = events[:n_epochs] + epochs = Epochs(raw, sel_events, tmin=tmin, tmax=tmax, preload=True) + + return epochs + + +def _create_ref_data(return_kernel=False): + """Create different data types that should produce equal TFRs.""" + def stc_generator(kernel, sens_data, data, verts, tmin, tstep): + for i in range(len(data)): + tmp_dat = (kernel[i], sens_data) if return_kernel else data[i] + yield SourceEstimate(tmp_dat, verts, tmin, tstep) + + def create_stc_list(kernel, sens_data, data, verts, tmin, tstep): + if return_kernel: + stcs = [SourceEstimate((kernel[i], sens_data), verts, tmin, tstep) + for i in range(len(data))] + else: + stcs = [SourceEstimate(data[i], verts, tmin, tstep) + for i in range(data.shape[0])] + return stcs + + sens_data = np.random.rand(21, 211) + kernel = np.random.rand(4, 33, 21) + data = np.tensordot(kernel, sens_data, axes=(-1, 0)) + chans = np.array([i for i in range(data.shape[1])]) + ch_names = list(chans.astype(str)) + ch_types = {} + for name in ch_names: + ch_types[name] = 'eeg' + verts = [chans, np.array([])] + sfreq = 300 + tstep = 1. / sfreq + tmin = 0 + + epochs_ref = EpochsArray(data, create_info(ch_names, sfreq), tmin=0) + epochs_ref.set_channel_types(ch_types) + stc_list = create_stc_list(kernel, sens_data, data, verts, tmin, tstep) + stc_gen = stc_generator(kernel, sens_data, data, verts, tmin, tstep) + + evoked_ref = epochs_ref.average(picks='all') + if return_kernel: + stc_single = SourceEstimate((np.mean(kernel, axis=0), sens_data), + verts, tmin, tstep) + else: + stc_single = SourceEstimate(evoked_ref.data, verts, tmin, tstep) + + return epochs_ref, stc_list, stc_gen, evoked_ref, stc_single + + +def _check_stc_list_input(func, **kwargs): + """Check if functions raise Errors for invalid stc list input.""" + stc_data = np.ones([3, 64]) + verts = [np.array([1, 2, 3]), np.array([])] + tstep = 1. / 128. + tmin = 0.1 + stc_ref = SourceEstimate(stc_data, verts, tmin, tstep) + stc_1 = VolSourceEstimate(stc_data, verts, tmin, tstep) + stc_2 = SourceEstimate(stc_data, verts, tmin=0.2, tstep=tstep) + stc_3 = SourceEstimate(stc_data, verts, tmin, tstep=1. / 129.) + + with pytest.raises(TypeError, match="must be of the same " + "SourceEstimate type"): + func([stc_ref, stc_1], **kwargs) + + with pytest.raises(ValueError, match="must have the same tmin"): + func([stc_ref, stc_2], **kwargs) + + with pytest.raises(ValueError, match="must have the same sfreq"): + func([stc_ref, stc_3], **kwargs) + + with pytest.raises(TypeError, match="must consist of " + "SourceEstimate objects"): + func([dict(invalid="invalid"), stc_ref], **kwargs) + + +@testing.requires_testing_data +@pytest.mark.parametrize('n_epochs', [1, 3]) +@pytest.mark.parametrize('return_itc', [True, False]) +def test_morlet_induced_power_equivalence(n_epochs, return_itc): + """Test equivalence of tfr_morlet(stc) and source_induced_power.""" + epochs = _prepare_epochs(n_epochs) + inv = read_inverse_operator(stc_inv_fname) + label = read_label(stc_label_fname) + + method = "dSPM" + pick_ori = "normal" + l2 = 1. / 9. 
+ freqs = np.array([10, 12, 14, 16]) + n_cycles = 2 + use_fft = True + decim = 1 + zero_mean = False + + stcs = apply_inverse_epochs(epochs, inv, lambda2=l2, method=method, + pick_ori=pick_ori, label=label, prepared=False, + delayed=False) + stfr = tfr_morlet(stcs, freqs=freqs, n_cycles=n_cycles, use_fft=use_fft, + decim=decim, zero_mean=zero_mean, return_itc=return_itc, + output='power', average=True) + stfr_ref, itc_ref = \ + source_induced_power(epochs, inv, lambda2=l2, method=method, + pick_ori=pick_ori, label=label, prepared=False, + freqs=freqs, n_cycles=n_cycles, use_fft=use_fft, + decim=decim, zero_mean=zero_mean, pca=False) + + if return_itc: + assert_allclose(np.reshape(stfr[0].data, stfr_ref.shape), stfr_ref) + assert_allclose(np.reshape(stfr[1].data, itc_ref.shape), itc_ref) + else: + assert_allclose(np.reshape(stfr.data, stfr_ref.shape), stfr_ref) + + +@pytest.mark.filterwarnings('ignore:.*The unit .*? has changed from NA to V.') +@pytest.mark.parametrize('tfr_func', [tfr_morlet, tfr_multitaper]) +@pytest.mark.parametrize('return_itc, average', + [[False, False], + [False, True], + [True, True]]) +@pytest.mark.parametrize('kernel', [True, False]) +def test_stfr_equivalence(tfr_func, return_itc, average, kernel): + """Test if SourceTFRs are computed in the same way as sensor space TFRs.""" + n_cycles = 3 + use_fft = True + decim = 1 + freqs = [10, 12, 14, 16] + + epochs_ref, stc_list, stc_gen, evoked_ref, stc_single =\ + _create_ref_data(kernel) + + ep_tfrs = tfr_func(epochs_ref, freqs=freqs, n_cycles=n_cycles, + use_fft=use_fft, decim=decim, return_itc=return_itc, + average=average) + list_stfrs = tfr_func(stc_list, freqs=freqs, n_cycles=n_cycles, + use_fft=use_fft, decim=decim, return_itc=return_itc, + average=average) + gen_stfrs = tfr_func(stc_gen, freqs=freqs, n_cycles=n_cycles, + use_fft=use_fft, decim=decim, return_itc=return_itc, + average=average) + + # if average is False, stfr shapes need to be switched to epochs shape + trans = (0, 1, 2) if average else (1, 0, 2, 3) + + if not return_itc: + # make sure we can loop over variables for both return_itc options + ep_tfrs, list_stfrs, gen_stfrs = [ep_tfrs], [list_stfrs], [gen_stfrs] + + # compare power as well as itc data + for epoch_tfr, list_stfr, gen_stfr in zip(ep_tfrs, list_stfrs, gen_stfrs): + assert_allclose(list_stfr.data.transpose(trans), epoch_tfr.data) + assert_allclose(gen_stfr.data.transpose(trans), epoch_tfr.data) + + assert_equal(list_stfr.method, epoch_tfr.method) + assert_equal(gen_stfr.method, epoch_tfr.method) + + evoked_tfr = tfr_func(evoked_ref, freqs=freqs, n_cycles=n_cycles, + use_fft=use_fft, decim=decim, + return_itc=False, average=average) + single_stfr = tfr_func(stc_single, freqs=freqs, n_cycles=n_cycles, + use_fft=use_fft, decim=decim, + return_itc=False, average=average) + + assert_allclose(single_stfr.data.transpose(trans), evoked_tfr.data) + assert_equal(single_stfr.method, evoked_tfr.method) + + run_tests_if_main() diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 8883584840c..133e7b09b4d 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -11,6 +11,7 @@ from copy import deepcopy from functools import partial +from inspect import isgenerator from math import sqrt import numpy as np @@ -343,14 +344,8 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet', _check_tfr_param(freqs, sfreq, method, zero_mean, n_cycles, time_bandwidth, use_fft, decim, output) - # Setup wavelet - if method == 'morlet': - W = morlet(sfreq, freqs, 
n_cycles=n_cycles, zero_mean=zero_mean) - Ws = [W] # to have same dimensionality as the 'multitaper' case - - elif method == 'multitaper': - Ws = _make_dpss(sfreq, freqs, n_cycles=n_cycles, - time_bandwidth=time_bandwidth, zero_mean=zero_mean) + Ws = _create_tapers(method, sfreq, freqs, n_cycles, zero_mean, + time_bandwidth) # Check wavelets if len(Ws[0][0]) > epoch_data.shape[2]: @@ -489,9 +484,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim): The decimation slice: e.g. power[:, decim] """ # Set output type - dtype = np.float - if output in ['complex', 'avg_power_itc']: - dtype = np.complex + # avg_power_itc is stored as power + 1i * itc + dtype = np.complex if output in ['complex', 'avg_power_itc'] else np.float # Init outputs decim = _check_decim(decim) @@ -512,21 +506,16 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim): # Loop across epochs for epoch_idx, tfr in enumerate(coefs): - # Transform complex values - if output in ['power', 'avg_power']: - tfr = (tfr * tfr.conj()).real # power - elif output == 'phase': - tfr = np.angle(tfr) - elif output == 'avg_power_itc': - tfr_abs = np.abs(tfr) - plf += tfr / tfr_abs # phase - tfr = tfr_abs ** 2 # power - elif output == 'itc': + # Transform itc complex values + if "itc" in output: plf += tfr / np.abs(tfr) # phase - continue # not need to stack anything else than plf + # Transform other tfr complex values + tfr = _transform_complex_values(tfr, output) - # Stack or add - if ('avg_' in output) or ('itc' in output): + # Stack, add, or continue + if output == 'itc': + continue + elif ('avg_' in output) or ('itc' in output): tfrs += tfr else: tfrs[epoch_idx] += tfr @@ -546,6 +535,208 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim): return tfrs +def _transform_complex_values(data, output): + + if output in ['power', 'avg_power']: + data = (data * data.conj()).real # power + elif output == 'phase': + data = np.angle(data) + elif output == 'avg_power_itc': + data = np.abs(data) ** 2 # power + + return data + + +def _create_tapers(method, sfreq, freqs, n_cycles, zero_mean, time_bandwidth): + """Create multitaper or morlet wavelets for TFR functions.""" + # Setup wavelet + if method == 'morlet': + # make a list to have the same dims as the 'multitaper' case + Ws = [morlet(sfreq, freqs, n_cycles=n_cycles, + zero_mean=zero_mean)] + + elif method == 'multitaper': + Ws = _make_dpss(sfreq, freqs, n_cycles=n_cycles, + time_bandwidth=time_bandwidth, + zero_mean=zero_mean) + return Ws + + +# FIXME: n_jobs is not needed here, but still needs to be passed since it's a +# tfr_param. This should be fixed +def _tfr_loop_list(list_data, freqs, method='morlet', n_cycles=7.0, + zero_mean=None, time_bandwidth=None, use_fft=True, decim=1, + output='complex', n_jobs=None, verbose=None): + """Compute time-frequency transforms for lists of SourceEstimate types. + + Parameters + ---------- + epoch_data : array of shape (n_epochs, n_channels, n_times) + The epochs. + freqs : array-like of floats, shape (n_freqs) + The frequencies. + method : 'multitaper' | 'morlet', default 'morlet' + The time-frequency method. 'morlet' convolves a Morlet wavelet. + 'multitaper' uses Morlet wavelets windowed with multiple DPSS + multitapers. + n_cycles : float | array of float, default 7.0 + Number of cycles in the Morlet wavelet. Fixed number + or one per frequency. + zero_mean : bool | None, default None + None means True for method='multitaper' and False for method='morlet'. + If True, make sure the wavelets have a mean of zero. 
+ time_bandwidth : float, default None + If None and method=multitaper, will be set to 4.0 (3 tapers). + Time x (Full) Bandwidth product. Only applies if + method == 'multitaper'. The number of good tapers (low-bias) is + chosen automatically based on this to equal floor(time_bandwidth - 1). + use_fft : bool, default True + Use the FFT for convolutions or not. + decim : int | slice, default 1 + To reduce memory usage, decimation factor after time-frequency + decomposition. + If `int`, returns tfr[..., ::decim]. + If `slice`, returns tfr[..., decim]. + + .. note:: + Decimation may create aliasing artifacts, yet decimation + is done after the convolutions. + + output : str, default 'complex' + + * 'complex' : single trial complex. + * 'power' : single trial power. + * 'phase' : single trial phase. + * 'avg_power' : average of single trial power. + * 'itc' : inter-trial coherence. + * 'avg_power_itc' : average of single trial power and inter-trial + coherence across trials. + n_jobs : int | str + Will bee ignored for this function. + %(verbose)s + + Returns + ------- + out : array + Time frequency transform of SourceEstimate. If output is in ['complex', + 'phase', 'power'], then shape of out is (n_dipoles, n_epochs, n_freqs, + n_times), else it is (n_dipoles, n_freqs, n_times). If output is + 'avg_power_itc', the real values code for 'avg_power' and the + imaginary values code for the 'itc': out = avg_power + i * itc + """ + from ..source_estimate import _BaseSourceEstimate + + # Initialize output + decim = _check_decim(decim) + n_freqs = len(freqs) + + # chose whether the arrays should be of complex or float dtype + # avg_power_itc is stored as power + 1i * itc + dtype = np.complex if output in ['complex', 'avg_power_itc'] else np.float + + # loop along the epochs (represented as list elements) + for epoch_idx, inst in enumerate(list_data): + + if not isinstance(inst, _BaseSourceEstimate): + raise TypeError("List or generator input must consist of " + "SourceEstimate objects. Got {}." + .format(type(inst))) + + # get the data array + # omit Error by setting return_itc=False + X, K = _get_data(inst, return_itc=False, fill_dims=False) + + if epoch_idx == 0: # initialize some stuff in the first epoch + + sfreq = inst.sfreq + type_ref = type(inst) + tmin_ref = inst.tmin + + # Check params + freqs, sfreq, zero_mean, n_cycles, time_bandwidth, decim = \ + _check_tfr_param(freqs, sfreq, method, zero_mean, n_cycles, + time_bandwidth, use_fft, decim, output) + + # create tapers + Ws = _create_tapers(method, sfreq, freqs, n_cycles, zero_mean, + time_bandwidth) + + # check tapers + if len(Ws[0][0]) > inst.shape[-1]: + raise ValueError('At least one of the wavelets is longer ' + 'than the signal. 
+                                 'shorter wavelets.')
+
+            # initialize tfrs
+            n_times = X[..., decim].shape[-1]
+            if inst._sens_data is not None:
+                n_verts = K.shape[0]
+            else:
+                n_verts = X.shape[0]
+
+            if ('avg_' in output) or ('itc' in output):
+                tfrs = np.zeros((n_verts, n_freqs, n_times), dtype=dtype)
+            else:
+                tfrs = np.zeros((1, n_verts, n_freqs, n_times), dtype=dtype)
+
+            # Inter-trial phase locking is summed up along the taper axis
+            if 'itc' in output:
+                plf = np.zeros((len(Ws), n_verts, n_freqs, n_times),
+                               dtype=np.complex)
+
+        else:  # for all iterations except the first
+
+            # make sure these elements have the same properties as the first
+            _check_stfr_list_elem(inst, type_ref, sfreq, tmin_ref)
+
+            # add a new epoch element to the tfrs array after each epoch
+            if not (('avg_' in output) or ('itc' in output)):
+                new_epoch = np.zeros((1, n_verts, n_freqs, n_times),
+                                     dtype=dtype)
+                tfrs = np.concatenate((tfrs, new_epoch), axis=0)
+
+        # loop over the tapers
+        for W_idx, W in enumerate(Ws):
+
+            tfr = cwt(X, W, use_fft=use_fft, mode='same', decim=decim)
+
+            # compute the full source time series from kernel and tfr
+            if inst._sens_data is not None:
+                tfr = np.tensordot(K, tfr, [1, 0])
+
+            # Transform itc complex values
+            if "itc" in output:
+                plf[W_idx] += tfr / np.abs(tfr)  # phase
+            # Transform other tfr complex values
+            tfr = _transform_complex_values(tfr, output)
+
+            # Stack, add, or continue
+            if output == 'itc':
+                continue
+            elif ('avg_' in output):
+                tfrs += tfr
+            else:
+                tfrs[epoch_idx] += tfr
+
+    # add the itc taper sums to the output
+    if "itc" in output:
+        for W_idx in range(len(Ws)):
+            # Compute inter trial coherence
+            if output == 'avg_power_itc':
+                tfrs += 1j * np.abs(plf[W_idx])
+            elif output == 'itc':
+                tfrs += np.abs(plf[W_idx])
+
+    # Normalization of average metrics
+    if ('avg_' in output) or ('itc' in output):
+        tfrs /= epoch_idx + 1
+
+    # Normalization by the number of tapers
+    tfrs /= len(Ws)
+
+    return tfrs, inst
+
+
 def cwt(X, Ws, use_fft=True, mode='same', decim=1):
     """Compute time freq decomposition with continuous wavelet transform.
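The kernelized branch above (the np.tensordot(K, tfr, [1, 0]) step) works because the wavelet convolution is linear: projecting the sensor-space transform through the imaging kernel gives the same coefficients as transforming the full source time courses. A minimal sketch of that equivalence with synthetic data, shown for illustration only and assuming the morlet and cwt helpers from this module (the array sizes are arbitrary):

import numpy as np
from mne.time_frequency.tfr import cwt, morlet

rng = np.random.RandomState(42)
sfreq = 200.
n_verts, n_sensors, n_times = 12, 5, 300
K = rng.randn(n_verts, n_sensors)          # imaging kernel
sens_data = rng.randn(n_sensors, n_times)  # one epoch of sensor data

Ws = morlet(sfreq, [10., 20.], n_cycles=4.)

# transform the small sensor array first, then project to source space ...
tfr_sens = cwt(sens_data, Ws, use_fft=True, mode='same')
tfr_kernel = np.tensordot(K, tfr_sens, [1, 0])

# ... which matches transforming the full source time courses directly
tfr_full = cwt(np.dot(K, sens_data), Ws, use_fft=True, mode='same')

np.testing.assert_allclose(tfr_kernel, tfr_full, rtol=1e-7, atol=1e-10)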
@@ -594,14 +785,10 @@ def cwt(X, Ws, use_fft=True, mode='same', decim=1): def _tfr_aux(method, inst, freqs, decim, return_itc, picks, average, output=None, **tfr_params): - from ..epochs import BaseEpochs + from ..source_estimate import _BaseSourceEstimate + """Help reduce redundancy between tfr_morlet and tfr_multitaper.""" decim = _check_decim(decim) - data = _get_data(inst, return_itc) - info = inst.info - - info, data = _prepare_picks(info, data, picks, axis=1) - del picks if average: if output == 'complex': @@ -616,23 +803,63 @@ def _tfr_aux(method, inst, freqs, decim, return_itc, picks, average, raise ValueError('Inter-trial coherence is not supported' ' with average=False') - out = _compute_tfr(data, freqs, info['sfreq'], method=method, - output=output, decim=decim, **tfr_params) + info = None + if isinstance(inst, list) or isgenerator(inst): + + out, inst = _tfr_loop_list(inst, freqs, method=method, + output=output, decim=decim, **tfr_params) + # nave is not needed for any SourceTFR type + nave = None + else: + data, _ = _get_data(inst, return_itc) + if average: + nave = len(data) + if isinstance(inst, _BaseSourceEstimate): + sfreq = inst.sfreq + else: + info, data = _prepare_picks(inst.info, data, picks, axis=1) + sfreq = inst.info['sfreq'] + del picks + + out = _compute_tfr(data, freqs, sfreq, method=method, + output=output, decim=decim, **tfr_params) + + if average and return_itc: + power, itc = out.real, out.imag + else: + power = out + times = inst.times[decim].copy() + # put the output objects together accordingly if average: + out = _assign_tfr_class(power, inst, info, freqs, times, average, + method='{}-power'.format(method), nave=nave) if return_itc: - power, itc = out.real, out.imag - else: - power = out - nave = len(data) - out = AverageTFR(info, power, times, freqs, nave, - method='%s-power' % method) - if return_itc: - out = (out, AverageTFR(info, itc, times, freqs, nave, - method='%s-itc' % method)) + out = (out, _assign_tfr_class(itc, inst, info, freqs, times, + average, nave=nave, + method='{}-itc'.format(method))) + else: + out = _assign_tfr_class(power, inst, info, freqs, times, average, + method='{}-power'.format(method)) + return out + + +def _assign_tfr_class(data, inst, info, freqs, times, average, method, + nave=None): + """Create different TFR objects, based on wanted type and output.""" + from ..epochs import BaseEpochs + from ..source_estimate import _BaseSourceEstimate + + # create SourceTFR types + if isinstance(inst, _BaseSourceEstimate): + out = _create_stfr(inst, data, freqs, method=method) + + # else create Epochs/AverageTFR + elif average: + out = AverageTFR(info, data, times, freqs, nave, + method=method) else: - power = out if isinstance(inst, BaseEpochs): meta = deepcopy(inst._metadata) evs = deepcopy(inst.events) @@ -640,11 +867,48 @@ def _tfr_aux(method, inst, freqs, decim, return_itc, picks, average, else: # if the input is of class Evoked meta = evs = ev_id = None + out = EpochsTFR(info, data, times, freqs, events=evs, + event_id=ev_id, metadata=meta, + method=method) + return out - out = EpochsTFR(info, power, times, freqs, method='%s-power' % method, - events=evs, event_id=ev_id, metadata=meta) - return out +def _check_stfr_list_elem(inst, type_ref, freq_ref, tmin_ref): + """Check if an stfr list/gen element matches the reference data.""" + if not isinstance(inst, type_ref): + raise TypeError("All computed elements must be of the same " + "SourceEstimate type. Got {} and {}." 
+                        .format(type_ref, type(inst)))
+    if inst.sfreq != freq_ref:
+        raise ValueError("All computed elements must have the same "
+                         "sfreq.")
+    if inst.tmin != tmin_ref:
+        raise ValueError("All computed elements must have the same "
+                         "tmin.")
+
+
+def _create_stfr(inst, out, freqs, method):
+    """Prepare data and create a SourceTFR object from _tfr_aux output."""
+    from ..source_estimate import VectorSourceEstimate, VolVectorSourceEstimate
+    from ..source_tfr import SourceTFR
+
+    if len(out.shape) == 4:  # epoched data
+        dims = ["dipoles", "epochs", "freqs", "times"]
+        # switch the epoch axis according to wanted dims
+        out = np.moveaxis(out, source=0, destination=1)
+
+    else:  # len 3 non-epoched data
+        # No need to reshape, only name the axes
+        dims = ["dipoles", "freqs", "times"]
+
+    if isinstance(inst, (VectorSourceEstimate, VolVectorSourceEstimate)):
+        # insert the orientation dimension after the dipoles
+        dims.insert(1, "orientations")
+        newshape = (out.shape[0] // 3, 3,) + out.shape[1:]
+        out = np.reshape(out, newshape)
+
+    return SourceTFR(out, inst.vertices, inst.tmin, inst.tstep, freqs,
+                     tuple(dims), method, inst.subject, inst._src_type)
 
 
 @verbose
@@ -655,8 +919,9 @@ def tfr_morlet(inst, freqs, n_cycles, use_fft=False, return_itc=True, decim=1,
 
     Parameters
     ----------
-    inst : Epochs | Evoked
-        The epochs or evoked object.
+    inst : Epochs | Evoked | SourceEstimate | list of SourceEstimate
+        The object to compute the TFR from: Epochs, Evoked, any type of
+        SourceEstimate, or a list of SourceEstimates of the same type.
     freqs : ndarray, shape (n_freqs,)
         The frequencies in Hz.
     n_cycles : float | ndarray, shape (n_freqs,)
@@ -695,9 +960,9 @@ def tfr_morlet(inst, freqs, n_cycles, use_fft=False, return_itc=True, decim=1,
 
     Returns
     -------
-    power : AverageTFR | EpochsTFR
+    power : AverageTFR | EpochsTFR | SourceTFR
         The averaged or single-trial power.
-    itc : AverageTFR | EpochsTFR
+    itc : AverageTFR | EpochsTFR | SourceTFR
         The inter-trial coherence (ITC). Only returned if return_itc
         is True.
@@ -799,8 +1064,9 @@ def tfr_multitaper(inst, freqs, n_cycles, time_bandwidth=4.0,
 
     Parameters
     ----------
-    inst : Epochs | Evoked
-        The epochs or evoked object.
+    inst : Epochs | Evoked | SourceEstimate | list of SourceEstimate
+        The object to compute the TFR from: Epochs, Evoked, any type of
+        SourceEstimate, or a list of SourceEstimates of the same type.
     freqs : ndarray, shape (n_freqs,)
         The frequencies in Hz.
     n_cycles : float | ndarray, shape (n_freqs,)
@@ -827,6 +1093,7 @@ def tfr_multitaper(inst, freqs, n_cycles, time_bandwidth=4.0,
         .. note:: Decimation may create aliasing artifacts.
     %(n_jobs)s
+        Will be ignored for list input.
     %(picks_good_data)s
     average : bool, default True
         If True average across Epochs.
@@ -836,9 +1103,9 @@ def tfr_multitaper(inst, freqs, n_cycles, time_bandwidth=4.0,
 
     Returns
     -------
-    power : AverageTFR | EpochsTFR
+    power : AverageTFR | EpochsTFR | SourceTFR
         The averaged or single-trial power.
-    itc : AverageTFR | EpochsTFR
+    itc : AverageTFR | EpochsTFR | SourceTFR
         The inter-trial coherence (ITC). Only returned if return_itc
         is True.
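For orientation, the intended end-to-end use of the API documented above looks roughly like the sketch below. It is an illustration only, not part of the diff: it assumes the MNE sample dataset is available locally with its standard file layout, and that tfr_morlet returns a SourceTFR for list/generator input as described here.

import numpy as np
import mne
from mne.datasets import sample
from mne.minimum_norm import read_inverse_operator, apply_inverse_epochs
from mne.time_frequency import tfr_morlet

data_path = sample.data_path()
raw = mne.io.read_raw_fif(data_path + '/MEG/sample/sample_audvis_raw.fif')
events = mne.find_events(raw, stim_channel='STI 014')
epochs = mne.Epochs(raw, events, event_id=1, tmin=-0.2, tmax=0.5,
                    preload=True)
inv = read_inverse_operator(data_path +
                            '/MEG/sample/sample_audvis-meg-oct-6-meg-inv.fif')

# one SourceEstimate per epoch; a generator also works for tfr_morlet
stcs = apply_inverse_epochs(epochs, inv, lambda2=1. / 9., method='dSPM',
                            pick_ori='normal', return_generator=True)

# list/generator input now yields a SourceTFR instead of an AverageTFR;
# for averaged power its dims are (dipoles, freqs, times)
freqs = np.arange(8., 13., 2.)
power = tfr_morlet(stcs, freqs=freqs, n_cycles=3., return_itc=False,
                   average=True)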
@@ -2145,19 +2412,48 @@ def combine_tfr(all_tfr, weights='nave'):
 
 # Utils
 
 
-def _get_data(inst, return_itc):
-    """Get data from Epochs or Evoked instance as epochs x ch x time."""
+def _get_data(inst, return_itc, fill_dims=True):
+    """Get data from Epochs, Evoked, or SourceEstimate instances."""
     from ..epochs import BaseEpochs
     from ..evoked import Evoked
-    if not isinstance(inst, (BaseEpochs, Evoked)):
-        raise TypeError('inst must be Epochs or Evoked')
+    from ..source_estimate import (_BaseSourceEstimate, VectorSourceEstimate,
+                                   VolVectorSourceEstimate)
+
+    if not isinstance(inst, (BaseEpochs, Evoked, _BaseSourceEstimate)):
+        raise TypeError('inst must be Epochs, Evoked, or any SourceEstimate')
+
+    kern = None
 
     if isinstance(inst, BaseEpochs):
         data = inst.get_data()
     else:
         if return_itc:
-            raise ValueError('return_itc must be False for evoked data')
-        data = inst.data[np.newaxis].copy()
-    return data
+            raise ValueError('return_itc must be False for evoked data '
+                             'or single SourceEstimates')
+
+        if isinstance(inst, _BaseSourceEstimate):
+            # get the data array
+            if inst._sens_data is not None and not fill_dims:
+                data = inst._sens_data
+                kern = inst._kernel
+                # combine the dipole and orientation dims for vector oris
+                if isinstance(inst, (VectorSourceEstimate,
+                                     VolVectorSourceEstimate)):
+                    kern = np.reshape(kern, [kern.shape[0] * kern.shape[1],
+                                             kern.shape[2]])
+            else:
+                data = inst.data
+                # combine the dipole and orientation dims for vector oris
+                if isinstance(inst, (VectorSourceEstimate,
+                                     VolVectorSourceEstimate)):
+                    data = np.reshape(data, [data.shape[0] * data.shape[1],
+                                             data.shape[2]])
+
+            if fill_dims:
+                data = data[np.newaxis]
+        else:  # is Evoked
+            data = inst.data[np.newaxis].copy()
+
+    return data, kern
 
 
 def _prepare_picks(info, data, picks, axis):
diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py
index 7f10c2acb3c..710fb9ab04d 100644
--- a/mne/viz/_3d.py
+++ b/mne/viz/_3d.py
@@ -2096,7 +2096,7 @@ def plot_and_correct(*args, **kwargs):
 
 def plot_vector_source_estimates(stc, subject=None, hemi='lh', colormap='hot',
                                  time_label='auto', smoothing_steps=10,
-                                 transparent=None, brain_alpha=0.4,
+                                 transparent=True, brain_alpha=0.4,
                                  overlay_alpha=None, vector_alpha=1.0,
                                  scale_factor=None, time_viewer=False,
                                  subjects_dir=None, figure=None, views='lat',
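To make the new (data, kern) return contract of _get_data concrete, a small synthetic sketch follows. It is an illustration only: it calls the private helper directly and relies on SourceEstimate's documented support for (kernel, sens_data) tuples; the names and sizes below are arbitrary.

import numpy as np
import mne
from mne.time_frequency.tfr import _get_data

n_verts, n_sensors, n_times = 20, 6, 100
vertices = [np.arange(n_verts), np.array([], dtype=int)]

# a plain SourceEstimate stores the full source data: kern comes back as None
stc = mne.SourceEstimate(np.random.randn(n_verts, n_times), vertices,
                         tmin=0., tstep=1e-3)
data, kern = _get_data(stc, return_itc=False, fill_dims=False)
assert kern is None and data.shape == (n_verts, n_times)

# a kernelized SourceEstimate keeps (kernel, sens_data) separate, so the
# caller can convolve the small sensor array and project afterwards
kernel = np.random.randn(n_verts, n_sensors)
sens_data = np.random.randn(n_sensors, n_times)
stc_kern = mne.SourceEstimate((kernel, sens_data), vertices,
                              tmin=0., tstep=1e-3)
data, kern = _get_data(stc_kern, return_itc=False, fill_dims=False)
assert data.shape == (n_sensors, n_times)
assert kern.shape == (n_verts, n_sensors)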