diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 2d731c83364..c0a645b5a86 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -41,6 +41,7 @@ Enhancements - Allow an image with intracranial electrode contacts (e.g. computed tomography) to be used without the freesurfer recon-all surfaces to locate contacts so that it doesn't have to be downsampled to freesurfer dimensions (for microelectrodes) and show an example :ref:`ex-ieeg-micro` with :func:`mne.transforms.apply_volume_registration_points` added to aid this transform (:gh:`11567` by `Alex Rockhill`_) - Use new :meth:`dipy.workflows.align.DiffeomorphicMap.transform_points` to transform a montage of intracranial contacts more efficiently (:gh:`11572` by `Alex Rockhill`_) - Add support for eyetracking data using :func:`mne.io.read_raw_eyelink` (:gh:`11152` by `Dominik Welke`_ and `Scott Huberty`_) +- Add to :ref:`ex-source-loc-methods` how to apply inverse methods to time-frequency resolved epochs and use :func:`mne.gui.view_vol_stc` to view the output (:gh:`11352` by `Alex Rockhill`_) Bugs ~~~~ diff --git a/doc/conf.py b/doc/conf.py index 2e6680a6b05..150fb0d3c51 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -271,6 +271,7 @@ # unlinkable 'CoregistrationUI', 'IntracranialElectrodeLocator', + 'VolSourceEstimateViewer', 'mne_qt_browser.figure.MNEQtBrowser', } numpydoc_validate = True diff --git a/doc/mri.rst b/doc/mri.rst index 420711eea89..2ef9fcb4d72 100644 --- a/doc/mri.rst +++ b/doc/mri.rst @@ -21,6 +21,7 @@ Step by step instructions for using :func:`gui.coregistration`: get_montage_volume_labels gui.coregistration gui.locate_ieeg + gui.view_vol_stc create_default_subject head_to_mni head_to_mri diff --git a/examples/inverse/evoked_ers_source_power.py b/examples/inverse/evoked_ers_source_power.py index 0ded1fc7aff..bdbc68832a0 100644 --- a/examples/inverse/evoked_ers_source_power.py +++ b/examples/inverse/evoked_ers_source_power.py @@ -12,6 +12,7 @@ """ # Authors: Luke Bloy # Eric Larson +# Alex Rockhill # # License: BSD-3-Clause @@ -21,7 +22,7 @@ import mne from mne.cov import compute_covariance from mne.datasets import somato -from mne.time_frequency import csd_morlet +from mne.time_frequency import csd_tfr from mne.beamformer import (make_dics, apply_dics_csd, make_lcmv, apply_lcmv_cov) from mne.minimum_norm import (make_inverse_operator, apply_inverse_cov) @@ -30,8 +31,10 @@ # %% # Reading the raw data and creating epochs: + data_path = somato.data_path() subject = '01' +subjects_dir = data_path / 'derivatives' / 'freesurfer' / 'subjects' task = 'somato' raw_fname = (data_path / 'sub-{}'.format(subject) / 'meg' / 'sub-{}_task-{}_meg.fif'.format(subject, task)) @@ -52,15 +55,13 @@ preload=True, decim=3) # Read forward operator and point to freesurfer subject directory -fname_fwd = (data_path / 'derivatives' / 'sub-{}'.format(subject) / +fwd_fname = (data_path / 'derivatives' / 'sub-{}'.format(subject) / 'sub-{}_task-{}-fwd.fif'.format(subject, task)) -subjects_dir = data_path / 'derivatives' / 'freesurfer' / 'subjects' - -fwd = mne.read_forward_solution(fname_fwd) +fwd = mne.read_forward_solution(fwd_fname) # %% -# Compute covariances -# ------------------- +# Compute covariances and cross-spectral density +# ---------------------------------------------- # ERS activity starts at 0.5 seconds after stimulus onset. Because these # data have been processed by MaxFilter directly (rather than MNE-Python's # version), we have to be careful to compute the rank with a more conservative @@ -69,17 +70,33 @@ # will be correctly preserved. rank = mne.compute_rank(epochs, tol=1e-6, tol_kind='relative') -active_win = (0.5, 1.5) -baseline_win = (-1, 0) -baseline_cov = compute_covariance(epochs, tmin=baseline_win[0], - tmax=baseline_win[1], method='shrunk', +win_active = (0.5, 1.5) +win_baseline = (-1, 0) +cov_baseline = compute_covariance(epochs, tmin=win_baseline[0], + tmax=win_baseline[1], method='shrunk', rank=rank, verbose=True) -active_cov = compute_covariance(epochs, tmin=active_win[0], tmax=active_win[1], +cov_active = compute_covariance(epochs, tmin=win_active[0], tmax=win_active[1], method='shrunk', rank=rank, verbose=True) -# Weighted averaging is already in the addition of covariance objects. -common_cov = baseline_cov + active_cov -mne.viz.plot_cov(baseline_cov, epochs.info) +# when the covariance objects are added together, they are scaled by the size +# of the window used to create them so that the average is properly weighted +cov_common = cov_baseline + cov_active +cov_baseline.plot(epochs.info) + +freqs = np.logspace(np.log10(12), np.log10(30), 9) + +# time-frequency decomposition +epochs_tfr = mne.time_frequency.tfr_morlet( + epochs, freqs=freqs, n_cycles=freqs / 2, return_itc=False, + average=False, output='complex') +epochs_tfr.decimate(20) # decimate for speed + +# compute cross-spectral density matrices +csd = csd_tfr(epochs_tfr, tmin=-1, tmax=1.5) +csd_baseline = csd_tfr(epochs_tfr, tmin=win_baseline[0], tmax=win_baseline[1]) +csd_ers = csd_tfr(epochs_tfr, tmin=win_active[0], tmax=win_active[1]) + +csd_baseline.plot() # %% # Compute some source estimates @@ -89,13 +106,7 @@ # See :ref:`ex-inverse-source-power` for more information about DICS. -def _gen_dics(active_win, baseline_win, epochs): - freqs = np.logspace(np.log10(12), np.log10(30), 9) - csd = csd_morlet(epochs, freqs, tmin=-1, tmax=1.5, decim=20) - csd_baseline = csd_morlet(epochs, freqs, tmin=baseline_win[0], - tmax=baseline_win[1], decim=20) - csd_ers = csd_morlet(epochs, freqs, tmin=active_win[0], tmax=active_win[1], - decim=20) +def _gen_dics(csd, ers_csd, csd_baseline, fwd): filters = make_dics(epochs.info, fwd, csd.mean(), pick_ori='max-power', reduce_rank=True, real_filter=True, rank=rank) stc_base, freqs = apply_dics_csd(csd_baseline.mean(), filters) @@ -105,30 +116,30 @@ def _gen_dics(active_win, baseline_win, epochs): # generate lcmv source estimate -def _gen_lcmv(active_cov, baseline_cov, common_cov): +def _gen_lcmv(active_cov, cov_baseline, common_cov, fwd): filters = make_lcmv(epochs.info, fwd, common_cov, reg=0.05, noise_cov=None, pick_ori='max-power') - stc_base = apply_lcmv_cov(baseline_cov, filters) - stc_act = apply_lcmv_cov(active_cov, filters) + stc_base = apply_lcmv_cov(cov_baseline, filters) + stc_act = apply_lcmv_cov(cov_active, filters) stc_act /= stc_base return stc_act # generate mne/dSPM source estimate -def _gen_mne(active_cov, baseline_cov, common_cov, fwd, info, method='dSPM'): - inverse_operator = make_inverse_operator(info, fwd, common_cov) - stc_act = apply_inverse_cov(active_cov, info, inverse_operator, +def _gen_mne(cov_active, cov_baseline, cov_common, fwd, info, method='dSPM'): + inverse_operator = make_inverse_operator(info, fwd, cov_common) + stc_act = apply_inverse_cov(cov_active, info, inverse_operator, method=method, verbose=True) - stc_base = apply_inverse_cov(baseline_cov, info, inverse_operator, + stc_base = apply_inverse_cov(cov_baseline, info, inverse_operator, method=method, verbose=True) stc_act /= stc_base return stc_act # Compute source estimates -stc_dics = _gen_dics(active_win, baseline_win, epochs) -stc_lcmv = _gen_lcmv(active_cov, baseline_cov, common_cov) -stc_dspm = _gen_mne(active_cov, baseline_cov, common_cov, fwd, epochs.info) +stc_dics = _gen_dics(csd, csd_ers, csd_baseline, fwd) +stc_lcmv = _gen_lcmv(cov_active, cov_baseline, cov_common, fwd) +stc_dspm = _gen_mne(cov_active, cov_baseline, cov_common, fwd, epochs.info) # %% # Plot source estimates @@ -152,3 +163,56 @@ def _gen_mne(active_cov, baseline_cov, common_cov, fwd, info, method='dSPM'): brain_dspm = stc_dspm.plot( hemi='rh', subjects_dir=subjects_dir, subject=subject, time_label='dSPM source power in the 12-30 Hz frequency band') + +# %% +# Use volume source estimate with time-frequency resolution +# --------------------------------------------------------- + +# make a volume source space +surface = subjects_dir / subject / 'bem' / 'inner_skull.surf' +vol_src = mne.setup_volume_source_space( + subject=subject, subjects_dir=subjects_dir, surface=surface, + pos=10, add_interpolator=False) # just for speed! + +conductivity = (0.3,) # one layer for MEG +model = mne.make_bem_model(subject=subject, ico=3, # just for speed + conductivity=conductivity, + subjects_dir=subjects_dir) +bem = mne.make_bem_solution(model) + +trans = fwd['info']['mri_head_t'] +vol_fwd = mne.make_forward_solution( + raw.info, trans=trans, src=vol_src, bem=bem, meg=True, eeg=True, + mindist=5.0, n_jobs=1, verbose=True) + +# Compute source estimate using MNE solver +snr = 3.0 +lambda2 = 1.0 / snr ** 2 +method = 'MNE' # use MNE method (could also be dSPM or sLORETA) + +# make a different inverse operator for each frequency so as to properly +# whiten the sensor data +inverse_operator = list() +for freq_idx in range(epochs_tfr.freqs.size): + # for each frequency, compute a separate covariance matrix + cov_baseline = csd_baseline.get_data(index=freq_idx, as_cov=True) + cov_baseline['data'] = cov_baseline['data'].real # only normalize by real + # then use that covariance matrix as normalization for the inverse + # operator + inverse_operator.append(mne.minimum_norm.make_inverse_operator( + epochs.info, vol_fwd, cov_baseline)) + +# finally, compute the stcs for each epoch and frequency +stcs = mne.minimum_norm.apply_inverse_tfr_epochs( + epochs_tfr, inverse_operator, lambda2, method=method, + pick_ori='vector') + +# %% +# Plot volume source estimates +# ---------------------------- + +viewer = mne.gui.view_vol_stc(stcs, subject=subject, subjects_dir=subjects_dir, + src=vol_src, inst=epochs_tfr) +viewer.go_to_extreme() # show the maximum intensity source vertex +viewer.set_cmap(vmin=0.25, vmid=0.8) +viewer.set_3d_view(azimuth=40, elevation=35, distance=350) diff --git a/mne/gui/__init__.py b/mne/gui/__init__.py index c86b413b634..088183c59cf 100644 --- a/mne/gui/__init__.py +++ b/mne/gui/__init__.py @@ -1,9 +1,11 @@ """Convenience functions for opening GUIs.""" # Authors: Christian Brodbeck +# Alex Rockhill # # License: BSD-3-Clause +import numpy as np from ..utils import verbose, get_config, warn @@ -234,13 +236,11 @@ def locate_ieeg(info, trans, base_image, subject=None, subjects_dir=None, gui : instance of IntracranialElectrodeLocator The graphical user interface (GUI) window. """ - from ..viz.backends._utils import _qt_app_exec + from ..viz.backends._utils import _init_mne_qtapp, _qt_app_exec from ._ieeg_locate import IntracranialElectrodeLocator - from qtpy.QtWidgets import QApplication - # get application - app = QApplication.instance() - if app is None: - app = QApplication(["Intracranial Electrode Locator"]) + + app = _init_mne_qtapp() + gui = IntracranialElectrodeLocator( info, trans, base_image, subject=subject, subjects_dir=subjects_dir, groups=groups, show=show, verbose=verbose) @@ -249,6 +249,103 @@ def locate_ieeg(info, trans, base_image, subject=None, subjects_dir=None, return gui +@verbose +def view_vol_stc(stcs, freq_first=True, subject=None, subjects_dir=None, + src=None, inst=None, use_int=True, show_topomap=True, + show=True, block=False, verbose=None): + """View a volume time and/or frequency source time course estimate. + + Parameters + ---------- + stcs : list of list | generator + The source estimates, the options are: 1) List of lists or + generators for epochs and frequencies (i.e. using + :func:`mne.minimum_norm.apply_inverse_tfr_epochs` or + :func:`mne.beamformer.apply_dics_tfr_epochs`-- in this case + use ``freq_first=False``), or 2) List of source estimates across + frequencies (e.g. :func::func:`mne.beamformer.apply_dics_csd`), + or 3) list of source estimates across epochs + (e.g. :func:`mne.minimum_norm.apply_inverse_epochs` and + :func:`mne.beamformer.apply_dics_epochs`--in this + case use ``freq_first=False``), or 4) Single + source estimates (e.g. :func:`mne.minimum_norm.apply_inverse` + and :func:`mne.beamformer.apply_dics`, note ``freq_first`` + will not be used in this case). + freq_first : bool + If frequencies are the outer list of ``stcs`` use ``True``. + %(subject)s + %(subjects_dir)s + src : instance of SourceSpaces + The volume source space for the ``stc``. + inst : EpochsTFR | AverageTFR | None + The time-frequency or data object to use to plot topography. + use_int : bool + If ``True``, cast the data to integers to reduce memory use. + show_topomap : bool + Whether to show the sensor topomap in the GUI. + show : bool + Show the GUI if True. + block : bool + Whether to halt program execution until the figure is closed. + %(verbose)s + + Returns + ------- + gui : instance of VolSourceEstimateViewer + The graphical user interface (GUI) window. + """ + from ..viz.backends._utils import _init_mne_qtapp, _qt_app_exec + from ._vol_stc import (VolSourceEstimateViewer, COMPLEX_DTYPE, + RANGE_VALUE, BASE_INT_DTYPE) + + app = _init_mne_qtapp() + + # cast to integers to lower memory usage, use custom complex data + # type if necessary + data = list() + # can be generator, compute using first stc object, just a general + # rescaling of data, does not need to be precise + scalar = None + for inner_stcs in (stcs if np.iterable(stcs) else [stcs]): + inner_data = list() + for stc in (inner_stcs if np.iterable(inner_stcs) else [inner_stcs]): + if use_int: + if scalar is None: + # this is an order of magnitude approximation, + # if another stc is 10x larger than the first one, + # it will have some clipping + scalar = (RANGE_VALUE - 1) / stc.data.real.max() / 10 + if np.iscomplexobj(stc.data): + stc_data = np.zeros(stc.data.shape, COMPLEX_DTYPE) + stc_data['re'] = np.clip(stc.data.real * scalar, + -RANGE_VALUE, RANGE_VALUE - 1) + stc_data['im'] = np.clip(stc.data.imag * scalar, + -RANGE_VALUE, RANGE_VALUE - 1) + inner_data.append(stc_data) + else: + inner_data.append(np.clip(stc.data * scalar, + -RANGE_VALUE, RANGE_VALUE - 1 + ).astype(BASE_INT_DTYPE)) + else: + inner_data.append(stc.data) + data.append(inner_data) + + data = np.array(data) + if data.ndim == 4: # scalar solution, add dimension at the end + data = data[:, :, :, None] + + # move frequencies to penultimate + data = data.transpose((1, 2, 3, 0, 4) if freq_first else (0, 2, 3, 1, 4)) + + gui = VolSourceEstimateViewer( + data, subject=subject, subjects_dir=subjects_dir, + src=src, inst=inst, show_topomap=show_topomap, show=show, + verbose=verbose) + if block: + _qt_app_exec(app) + return gui + + class _GUIScraper: """Scrape GUI outputs.""" @@ -258,11 +355,13 @@ def __repr__(self): def __call__(self, block, block_vars, gallery_conf): from ._ieeg_locate import IntracranialElectrodeLocator from ._coreg import CoregistrationUI + from ._vol_stc import VolSourceEstimateViewer from sphinx_gallery.scrapers import figure_rst from qtpy import QtGui for gui in block_vars['example_globals'].values(): if (isinstance(gui, (IntracranialElectrodeLocator, - CoregistrationUI)) and + CoregistrationUI, + VolSourceEstimateViewer)) and not getattr(gui, '_scraped', False) and gallery_conf['builder_name'] == 'html'): gui._scraped = True # monkey-patch but it's easy enough diff --git a/mne/gui/_ieeg_locate.py b/mne/gui/_ieeg_locate.py index a23590d7317..e9e96efa722 100644 --- a/mne/gui/_ieeg_locate.py +++ b/mne/gui/_ieeg_locate.py @@ -382,6 +382,7 @@ def _configure_status_bar(self, hbox=None): def _move_cursors_to_pos(self): super(IntracranialElectrodeLocator, self)._move_cursors_to_pos() + self._ch_list.setFocus() # remove focus from text edit def _group_channels(self, groups): diff --git a/mne/gui/_vol_stc.py b/mne/gui/_vol_stc.py new file mode 100644 index 00000000000..19f031577f5 --- /dev/null +++ b/mne/gui/_vol_stc.py @@ -0,0 +1,1202 @@ +# -*- coding: utf-8 -*- +"""Source estimate viewing graphical user interfaces (GUIs).""" + +# Authors: Alex Rockhill +# +# License: BSD (3-clause) + +import os.path as op +import numpy as np + +from qtpy import QtCore +from qtpy.QtWidgets import (QVBoxLayout, QHBoxLayout, QLabel, + QMessageBox, QWidget, QSlider, QPushButton, + QComboBox, QLineEdit, QFrame) +from matplotlib.colors import LinearSegmentedColormap + +from ._core import SliceBrowser +from .. import BaseEpochs +from ..baseline import rescale, _check_baseline +from ..defaults import DEFAULTS +from ..evoked import EvokedArray +from ..time_frequency import EpochsTFR +from ..io.constants import FIFF +from ..io.pick import _get_channel_types, _picks_to_idx, _pick_inst +from ..transforms import apply_trans +from ..utils import (_require_version, _validate_type, _check_range, fill_doc, + _check_option) +from ..viz.backends._utils import _qt_safe_window +from ..viz.utils import _get_cmap + +BASE_INT_DTYPE = np.int16 +COMPLEX_DTYPE = np.dtype([('re', BASE_INT_DTYPE), + ('im', BASE_INT_DTYPE)]) +RANGE_VALUE = 2**15 +# for taking the complex conjugate, we need to be able to +# temporarily store in a value where x**2 * 2 fits +OVERFLOW_DYPE = np.int32 + +VECTOR_SCALAR = 10 +SLIDER_WIDTH = 300 + + +def _check_consistent(items, name): + if not len(items): + return + for item in items[1:]: + if item != items[0]: + raise RuntimeError(f'Inconsistent attribute {name}, ' + f'got {items[0]} and {item}') + return items[0] + + +def _get_src_lut(src): + offset = 2 if src.kind == 'mixed' else 0 + inuse = [s['inuse'] for s in src[offset:]] + rr = np.concatenate( + [s['rr'][this_inuse.astype(bool)] + for s, this_inuse in zip(src[offset:], inuse)]) + shape = _check_consistent([this_src['shape'] for this_src in src], + "src['shape']") + # order='F' so that F-order flattening is faster + lut = -1 * np.ones(np.prod(shape), dtype=np.int64, order='F') + n_vertices_seen = 0 + for this_inuse in inuse: + this_inuse = this_inuse.astype(bool) + n_vertices = np.sum(this_inuse) + lut[this_inuse] = np.arange( + n_vertices_seen, n_vertices_seen + n_vertices) + n_vertices_seen += n_vertices + lut = np.reshape(lut, shape, order='F') + src_affine_ras = _check_consistent( + [this_src['mri_ras_t']['trans'] for this_src in src], + "src['mri_ras_t']") + src_affine_src = _check_consistent( + [this_src['src_mri_t']['trans'] for this_src in src], + "src['src_mri_t']") + affine = np.dot(src_affine_ras, src_affine_src) + affine[:3] *= 1e3 + return lut, affine, src_affine_src * 1000, rr * 1000 + + +def _make_vol(lut, stc_data): + vol = np.zeros(lut.shape, dtype=stc_data.dtype, order='F') * np.nan + vol[lut >= 0] = stc_data[lut[lut >= 0]] + return vol.reshape(lut.shape, order='F') + + +def _coord_to_coord(coord, vox_ras_t, ras_vox_t): + return apply_trans(ras_vox_t, apply_trans(vox_ras_t, coord)) + + +def _threshold_array(array, min_val, max_val): + array = array.astype(float) + array[array < min_val] = np.nan + array[array > max_val] = np.nan + return array + + +def _int_complex_conj(data): + # Since the mixed real * imaginary terms cancel out, the complex + # conjugate is the same as squaring and adding the real and imaginary. + # Case up the integer size temporarily to prevent overflow + conj = (data['re'].astype(OVERFLOW_DYPE))**2 + \ + (data['im'].astype(OVERFLOW_DYPE))**2 + return (conj // (conj.max() // RANGE_VALUE + 1)).astype(BASE_INT_DTYPE) + + +class VolSourceEstimateViewer(SliceBrowser): + """View a source estimate time-course time-frequency visualization.""" + + @_qt_safe_window(splash='_renderer.figure.splash', window='') + def __init__(self, data, subject=None, subjects_dir=None, src=None, + inst=None, show_topomap=True, show=True, verbose=None): + """View a volume time and/or frequency source time course estimate. + + Parameters + ---------- + data : array-like + An array of shape (``n_epochs``, ``n_sources``, ``n_ori``, + ``n_freqs``, ``n_times``). ``n_epochs`` may be 1 for data + averaged across epochs and ``n_freqs`` may be 1 for data + that is in time only and is not time-frequency decomposed. For + faster performance, data can be cast to integers or a + custom complex data type that uses integers as done by + :func:`mne.gui.view_vol_stc`. + %(subject)s + %(subjects_dir)s + src : instance of SourceSpaces + The volume source space for the ``stc``. + inst : EpochsTFR | AverageTFR | None + The time-frequency or data object to use to plot topography. + show_topomap : bool + Show the sensor topomap if ``True``. + show : bool + Show the GUI if ``True``. + block : bool + Whether to halt program execution until the figure is closed. + %(verbose)s + """ + _require_version('dipy', 'VolSourceEstimateViewer', '0.10.1') + if src is None: + raise NotImplementedError('`src` is required, surface source ' + 'estimate viewing is not yet supported') + if inst is None: + raise NotImplementedError( + '`data` as a source estimate object is ' + 'not yet supported so `inst` is required') + if not isinstance(data, np.ndarray) or data.ndim != 5: + raise ValueError('`data` must be an array of dimensions ' + '(n_epochs, n_sources, n_ori, n_freqs, n_times)') + if isinstance(inst, (BaseEpochs, EpochsTFR)) and \ + data.shape[0] != len(inst): + raise ValueError( + 'Number of epochs in `inst` does not match with `data`, ' + f'expected {data.shape[0]}, got {len(inst)}') + n_src_verts = sum([this_src['nuse'] for this_src in src]) + if src is not None and data.shape[1] != n_src_verts: + raise RuntimeError('Source vertices in `data` do not match with ' + 'source space vertices in `src`, ' + f'expected {n_src_verts}, got {data.shape[1]}') + if any([this_src['type'] == 'surf' for this_src in src]): + raise NotImplementedError('Surface and mixed source space ' + 'viewing is not implemented yet.') + if not all([s['coord_frame'] == FIFF.FIFFV_COORD_MRI for s in src]): + raise RuntimeError('The source space must be in the `mri`' + 'coordinate frame') + if hasattr(inst, 'freqs') and data.shape[3] != inst.freqs.size: + raise ValueError( + 'Frequencies in `inst` do not match with `data`, ' + f'expected {data.shape[3]}, got {inst.freqs.size}') + if hasattr(inst, 'freqs') and not (np.iscomplexobj(data) or + data.dtype == COMPLEX_DTYPE): + raise ValueError('Complex data is required for time-frequency ' + 'source estimates') + if data.shape[4] != inst.times.size: + raise ValueError( + 'Times in `inst` do not match with `data`, ' + f'expected {data.shape[4]}, got {inst.times.size}') + self._verbose = verbose # used for logging, unused here + self._data = data + self._src = src + self._inst = inst + self._show_topomap = show_topomap + (self._src_lut, self._src_vox_scan_ras_t, self._src_vox_ras_t, + self._src_rr) = _get_src_lut(src) + self._src_scan_ras_vox_t = np.linalg.inv(self._src_vox_scan_ras_t) + self._is_complex = np.iscomplexobj(self._data) or \ + self._data.dtype == COMPLEX_DTYPE + self._baseline = 'none' + self._bl_tmin = self._inst.times[0] + self._bl_tmax = self._inst.times[-1] + self._update = True # can be set to False to prevent double updates + # for time and frequency + # check if only positive values will be used + self._pos_support = self._is_complex or self._data.shape[2] > 1 or \ + (self._data >= 0).all() + self._cmap = _get_cmap('hot' if self._pos_support else 'mne') + + # set default variables for plotting + self._t_idx = self._inst.times.size // 2 + self._f_idx = self._inst.freqs.size // 2 \ + if hasattr(self._inst, 'freqs') else None + self._alpha = 0.75 + self._epoch_idx = 'Average' + ' Power' * self._is_complex + + # initialize current 3D image for chosen time and frequency + stc_data = self._pick_epoch(self._data) + + # take the vector magnitude, if scalar, does nothing + self._stc_data_vol = np.linalg.norm(stc_data, axis=1) + + stc_max = np.nanmax(self._stc_data_vol) + self._stc_min = min([np.nanmin(self._stc_data_vol), stc_max]) + self._stc_range = max([stc_max, -self._stc_min]) - self._stc_min + + stc_data_vol = self._pick_stc_tfr(self._stc_data_vol) + self._stc_img = _make_vol(self._src_lut, stc_data_vol) + + super(VolSourceEstimateViewer, self).__init__( + subject=subject, subjects_dir=subjects_dir) + + if src._subject != op.basename(self._subject_dir): + raise RuntimeError( + f'Source space subject ({src._subject})-freesurfer subject' + f'({op.basename(self._subject_dir)}) mismatch') + + # make source time course plots + self._images['stc'] = list() + src_shape = np.array(self._src_lut.shape) + corners = [ # center pixel on location + _coord_to_coord( + (0,) * 3, self._src_vox_scan_ras_t, self._scan_ras_vox_t), + _coord_to_coord( + src_shape - 1, self._src_vox_scan_ras_t, self._scan_ras_vox_t) + ] + src_coord = self._get_src_coord() + for axis in range(3): + stc_slice = np.take(self._stc_img, src_coord[axis], axis=axis).T + x_idx, y_idx = self._xy_idx[axis] + extent = [corners[0][x_idx], corners[1][x_idx], + corners[1][y_idx], corners[0][y_idx]] + self._images['stc'].append(self._figs[axis].axes[0].imshow( + stc_slice, aspect='auto', extent=extent, cmap=self._cmap, + alpha=self._alpha, zorder=2)) + + self._data_max = abs(stc_data).max() + if self._data.shape[2] > 1 and not self._is_complex: + # also compute vectors for chosen time + self._stc_vectors = self._pick_stc_tfr(stc_data).astype(float) + self._stc_vectors /= self._data_max + self._stc_vectors_masked = self._stc_vectors.copy() + + assert self._data.shape[2] == 3 + self._vector_mapper, self._vector_data = self._renderer.quiver3d( + *self._src_rr.T, *(VECTOR_SCALAR * self._stc_vectors_masked.T), + color=None, mode='2darrow', scale_mode='vector', scale=1, + opacity=1) + self._vector_actor = self._renderer._actor(self._vector_mapper) + self._vector_actor.GetProperty().SetLineWidth(2.) + self._renderer.plotter.add_actor(self._vector_actor, render=False) + + # initialize 3D volumetric rendering + # TO DO: add surface source space viewing as elif + if any([this_src['type'] == 'vol' for this_src in self._src]): + scalars = np.array(np.where(np.isnan(self._stc_img), 0, 1.)) + spacing = np.diag(self._src_vox_ras_t)[:3] + origin = self._src_vox_ras_t[:3, 3] - spacing / 2. + center = 0.5 * self._stc_range - self._stc_min + self._grid, self._grid_mesh, self._volume_pos, self._volume_neg = \ + self._renderer._volume( + dimensions=src_shape, origin=origin, + spacing=spacing, + scalars=scalars.flatten(order='F'), + surface_alpha=self._alpha, + resolution=0.4, blending='mip', center=center) + self._volume_pos_actor = self._renderer.plotter.add_actor( + self._volume_pos, render=False)[0] + self._volume_neg_actor = self._renderer.plotter.add_actor( + self._volume_neg, render=False)[0] + _, grid_prop = self._renderer.plotter.add_actor( + self._grid_mesh, render=False) + grid_prop.SetOpacity(0.1) + self._scalar_bar = self._renderer.scalarbar( + source=self._volume_pos_actor, n_labels=8, color='black', + bgcolor='white', label_font_size=10) + self._scalar_bar.SetOrientationToVertical() + self._scalar_bar.SetHeight(0.6) + self._scalar_bar.SetWidth(0.05) + self._scalar_bar.SetPosition(0.02, 0.2) + + self._update_cmap() # must be called for volume to render properly + # keep focus on main window so that keypress events work + self.setFocus() + if show: + self.show() + + def _get_min_max_val(self): + """Get the minimum and maximum non-transparent values.""" + return [self._cmap_sliders[i].value() / SLIDER_WIDTH * + self._stc_range + self._stc_min for i in (0, 2)] + + def _get_src_coord(self): + """Get the current slice transformed to source space.""" + return tuple(np.round(_coord_to_coord( + self._current_slice, self._vox_scan_ras_t, + self._src_scan_ras_vox_t)).astype(int)) + + def _update_stc_pick(self): + """Update the normalized data with the epoch picked.""" + stc_data = self._pick_epoch(self._data) + self._stc_data_vol = self._apply_vector_norm(stc_data) + self._stc_data_vol = self._apply_baseline_correction( + self._stc_data_vol) + # deal with baseline infinite numbers + inf_mask = np.isinf(self._stc_data_vol) + if inf_mask.any(): + self._stc_data_vol[inf_mask] = np.nan + stc_max = np.nanmax(self._stc_data_vol) + self._stc_min = min([np.nanmin(self._stc_data_vol), -stc_max]) + self._stc_range = max([stc_max, -self._stc_min]) - self._stc_min + + def _update_vectors(self): + if self._data.shape[2] > 1 and not self._is_complex: + # pick vector as well + self._stc_vectors = self._pick_stc_tfr(self._data) + self._stc_vectors = self._pick_epoch( + self._stc_vectors).astype(float) + self._stc_vectors /= self._data_max + self._update_vector_threshold() + self._plot_vectors() + + def _update_vector_threshold(self): + """Update the threshold for the vectors.""" + # apply threshold, use same mask as for stc_img + stc_data = self._pick_stc_tfr(self._stc_data_vol) + min_val, max_val = self._get_min_max_val() + self._stc_vectors_masked = self._stc_vectors.copy() + self._stc_vectors_masked[stc_data < min_val] = np.nan + self._stc_vectors_masked[stc_data > max_val] = np.nan + + def _update_stc_volume(self): + """Select volume based on the current time, frequency and vertex.""" + stc_data = self._pick_stc_tfr(self._stc_data_vol) + self._stc_img = _make_vol(self._src_lut, stc_data) + self._stc_img = _threshold_array( + self._stc_img, *self._get_min_max_val()) + + def _update_stc_all(self): + """Update the data in both the slice plots and the data plot.""" + # pick new epochs + baseline correction combination + self._update_stc_pick() + self._update_stc_images() # and then make the new volume + self._update_intensity() + self._update_cmap() # note: this updates stc slice plots + self._plot_data() + if self._show_topomap and self._update: + self._plot_topomap() + + def _pick_stc_image(self): + """Select time-(frequency) image based on vertex.""" + return self._pick_stc_vertex(self._stc_data_vol) + + def _pick_epoch(self, stc_data): + """Select the source time course epoch based on the parameters.""" + if self._epoch_idx == 'Average': + if stc_data.dtype == BASE_INT_DTYPE: + stc_data = stc_data.mean(axis=0).astype(BASE_INT_DTYPE) + else: + stc_data = stc_data.mean(axis=0) + elif self._epoch_idx == 'Average Power': + if stc_data.dtype == COMPLEX_DTYPE: + stc_data = np.sum(_int_complex_conj( + stc_data) // stc_data.shape[0], axis=0, + dtype=BASE_INT_DTYPE) + else: + stc_data = (stc_data * stc_data.conj()).real.mean(axis=0) + elif self._epoch_idx == 'ITC': + if stc_data.dtype == COMPLEX_DTYPE: + stc_data = stc_data['re'].astype(np.complex64) + \ + 1j * stc_data['im'] + stc_data = np.abs((stc_data / np.abs(stc_data)).mean(axis=0)) + else: + stc_data = np.abs((stc_data / np.abs(stc_data)).mean(axis=0)) + else: + stc_data = stc_data[int(self._epoch_idx.replace('Epoch ', ''))] + if stc_data.dtype == COMPLEX_DTYPE: + stc_data = _int_complex_conj(stc_data) + elif self._is_complex: + stc_data = (stc_data * stc_data.conj()).real + return stc_data + + def _apply_vector_norm(self, stc_data, axis=1): + """Take the vector norm if source data is vector.""" + if self._epoch_idx == 'ITC': + stc_data = np.max(stc_data, axis=axis) # take maximum ITC + elif stc_data.shape[axis] > 1: + stc_data = np.linalg.norm(stc_data, axis=axis) # take magnitude + # if self._data.dtype in (COMPLEX_DTYPE, BASE_INT_DTYPE): + # stc_data = stc_data.round().astype(BASE_INT_DTYPE) + else: + stc_data = np.take(stc_data, 0, axis=axis) + return stc_data + + def _apply_baseline_correction(self, stc_data): + """Apply the chosen baseline correction to the data.""" + if self._baseline != 'none': # do baseline correction + stc_data = rescale( + stc_data.astype(float), times=self._inst.times, + baseline=(float(self._bl_tmin), float(self._bl_tmax)), + mode=self._baseline, copy=True) + return stc_data + + def _pick_stc_vertex(self, stc_data): + """Select the vertex based on the cursor position.""" + src_coord = self._get_src_coord() + if all([coord >= 0 and coord < dim for coord, dim in zip( + src_coord, self._src_lut.shape)]) and \ + self._src_lut[src_coord] >= 0: + stc_data = stc_data[self._src_lut[src_coord]] + else: # out-of-bounds or unused vertex + stc_data = np.zeros(stc_data[:, 0].shape) * np.nan + return stc_data + + def _pick_stc_tfr(self, stc_data): + """Select the frequency and time based on GUI values.""" + stc_data = np.take(stc_data, self._t_idx, axis=-1) + f_idx = 0 if self._f_idx is None else self._f_idx + stc_data = np.take(stc_data, f_idx, axis=-1) + return stc_data + + def _configure_ui(self): + """Configure the main appearance of the user interface.""" + toolbar = self._configure_toolbar() + slider_bar = self._configure_sliders() + status_bar = self._configure_status_bar() + data_plot = self._configure_data_plot() + + plot_vbox = QVBoxLayout() + plot_vbox.addLayout(self._plt_grid) + + if self._show_topomap: + data_hbox = QHBoxLayout() + topo_plot = self._configure_topo_plot() + data_hbox.addWidget(topo_plot) + data_hbox.addWidget(data_plot) + plot_vbox.addLayout(data_hbox) + else: + plot_vbox.addWidget(data_plot) + + main_hbox = QHBoxLayout() + main_hbox.addLayout(slider_bar) + main_hbox.addLayout(plot_vbox) + + main_vbox = QVBoxLayout() + main_vbox.addLayout(toolbar) + main_vbox.addLayout(main_hbox) + main_vbox.addLayout(status_bar) + + central_widget = QWidget() + central_widget.setLayout(main_vbox) + self.setCentralWidget(central_widget) + + def _configure_toolbar(self): + """Make a bar with buttons for user interactions.""" + hbox = QHBoxLayout() + + help_button = QPushButton('Help') + help_button.released.connect(self._show_help) + hbox.addWidget(help_button) + + hbox.addStretch(8) + + if self._data.shape[0] > 1: + self._epoch_selector = QComboBox() + if self._is_complex: + self._epoch_selector.addItems(['Average Power']) + self._epoch_selector.addItems(['ITC']) + else: + self._epoch_selector.addItems(['Average']) + self._epoch_selector.addItems( + [f'Epoch {i}' for i in range(self._data.shape[0])]) + self._epoch_selector.setCurrentText(self._epoch_idx) + self._epoch_selector.currentTextChanged.connect(self._update_epoch) + self._epoch_selector.setSizeAdjustPolicy( + QComboBox.AdjustToContents) + self._epoch_selector.keyPressEvent = self.keyPressEvent + hbox.addWidget(self._epoch_selector) + + return hbox + + def _show_help(self): + """Show the help menu.""" + QMessageBox.information( + self, 'Help', + "Help:\n" + "'+'/'-': zoom\nleft/right arrow: left/right\n" + "up/down arrow: superior/inferior\n" + "left angle bracket/right angle bracket: anterior/posterior") + + def _configure_sliders(self): + """Make a bar with sliders on it.""" + + def make_label(name): + label = QLabel(name) + label.setAlignment(QtCore.Qt.AlignCenter) + return label + + # modified from: + # https://stackoverflow.com/questions/52689047/moving-qslider-to-mouse-click-position + class Slider(QSlider): + + def mouseReleaseEvent(self, event): + if event.button() == QtCore.Qt.LeftButton: + event.accept() + value = (self.maximum() - self.minimum()) * \ + event.pos().x() / self.width() + self.minimum() + value = np.clip(value, 0, SLIDER_WIDTH) + self.setValue(int(round(value))) + else: + super(Slider, self).mouseReleaseEvent(event) + + def make_slider(smin, smax, sval, sfun=None): + slider = Slider(QtCore.Qt.Horizontal) + slider.setMinimum(int(round(smin))) + slider.setMaximum(int(round(smax))) + slider.setValue(int(round(sval))) + slider.setTracking(False) # only update on release + if sfun is not None: + slider.valueChanged.connect(sfun) + slider.keyPressEvent = self.keyPressEvent + slider.setMinimumWidth(SLIDER_WIDTH) + return slider + + slider_layout = QVBoxLayout() + slider_layout.setContentsMargins(11, 11, 11, 11) # for aesthetics + + if hasattr(self._inst, 'freqs'): + slider_layout.addWidget(make_label('Frequency (Hz)')) + self._freq_slider = make_slider( + 0, self._inst.freqs.size - 1, self._f_idx, self._update_freq) + slider_layout.addWidget(self._freq_slider) + freq_hbox = QHBoxLayout() + freq_hbox.addWidget(make_label(str(self._inst.freqs[0].round(2)))) + freq_hbox.addStretch(1) + freq_hbox.addWidget(make_label(str(self._inst.freqs[-1].round(2)))) + slider_layout.addLayout(freq_hbox) + self._freq_label = make_label( + f'Freq = {self._inst.freqs[self._f_idx].round(2)} Hz') + slider_layout.addWidget(self._freq_label) + slider_layout.addStretch(1) + + slider_layout.addWidget(make_label('Time (s)')) + self._time_slider = make_slider(0, self._inst.times.size - 1, + self._t_idx, self._update_time) + slider_layout.addWidget(self._time_slider) + time_hbox = QHBoxLayout() + time_hbox.addWidget(make_label(str(self._inst.times[0].round(2)))) + time_hbox.addStretch(1) + time_hbox.addWidget(make_label(str(self._inst.times[-1].round(2)))) + slider_layout.addLayout(time_hbox) + self._time_label = make_label( + f'Time = {self._inst.times[self._t_idx].round(2)} s') + slider_layout.addWidget(self._time_label) + slider_layout.addStretch(1) + + slider_layout.addWidget(make_label('Alpha')) + self._alpha_slider = make_slider( + 0, SLIDER_WIDTH, int(self._alpha * SLIDER_WIDTH), + self._update_alpha) + slider_layout.addWidget(self._alpha_slider) + self._alpha_label = make_label(f'Alpha = {self._alpha}') + slider_layout.addWidget(self._alpha_label) + slider_layout.addStretch(1) + + slider_layout.addWidget(make_label('min / mid / max')) + self._cmap_sliders = [ + make_slider(0, SLIDER_WIDTH, 0, self._update_cmap), + make_slider(0, SLIDER_WIDTH, SLIDER_WIDTH // 2, + self._update_cmap), + make_slider(0, SLIDER_WIDTH, SLIDER_WIDTH, self._update_cmap)] + for slider in self._cmap_sliders: + slider_layout.addWidget(slider) + slider_layout.addStretch(1) + + return slider_layout + + def _configure_status_bar(self, hbox=None): + hbox = QHBoxLayout() if hbox is None else hbox + + hbox.addWidget(QLabel('Baseline')) + self._baseline_selector = QComboBox() + self._baseline_selector.addItems(['none', 'mean', 'ratio', 'logratio', + 'percent', 'zscore', 'zlogratio']) + self._baseline_selector.setCurrentText('none') + self._baseline_selector.currentTextChanged.connect( + self._update_baseline) + self._baseline_selector.setSizeAdjustPolicy(QComboBox.AdjustToContents) + self._baseline_selector.keyPressEvent = self.keyPressEvent + hbox.addWidget(self._baseline_selector) + + hbox.addWidget(QLabel('tmin =')) + self._bl_tmin_textbox = QLineEdit(str(round(self._bl_tmin, 2))) + self._bl_tmin_textbox.setMaximumWidth(60) + self._bl_tmin_textbox.focusOutEvent = self._update_baseline_tmin + hbox.addWidget(self._bl_tmin_textbox) + + hbox.addWidget(QLabel('tmax =')) + self._bl_tmax_textbox = QLineEdit(str(round(self._bl_tmax, 2))) + self._bl_tmax_textbox.setMaximumWidth(60) + self._bl_tmax_textbox.focusOutEvent = self._update_baseline_tmax + hbox.addWidget(self._bl_tmax_textbox) + + # add separator for clarity + sep = QFrame() + sep.setFrameShape(QFrame.VLine) + sep.setFrameShadow(QFrame.Sunken) + hbox.addWidget(sep) + + hbox.addStretch(3 if self._f_idx is None else 2) + + if self._show_topomap: + hbox.addWidget(QLabel('Topo Data=')) + self._data_type_selector = QComboBox() + self._data_type_selector.addItems( + _get_channel_types(self._inst.info, picks='data', unique=True)) + self._data_type_selector.currentTextChanged.connect( + self._update_data_type) + self._data_type_selector.setSizeAdjustPolicy( + QComboBox.AdjustToContents) + self._data_type_selector.keyPressEvent = self.keyPressEvent + hbox.addWidget(self._data_type_selector) + hbox.addStretch(1) + + if self._f_idx is not None: + hbox.addWidget(QLabel('Interpolate')) + self._interp_button = QPushButton('On') + self._interp_button.setMaximumWidth(25) # not too big + self._interp_button.setStyleSheet("background-color: green") + hbox.addWidget(self._interp_button) + self._interp_button.released.connect(self._toggle_interp) + hbox.addStretch(1) + + self._go_to_extreme_button = QPushButton('Go to Max') + self._go_to_extreme_button.released.connect(self.go_to_extreme) + hbox.addWidget(self._go_to_extreme_button) + hbox.addStretch(2) + + self._intensity_label = QLabel('') # update later + hbox.addWidget(self._intensity_label) + + # add SliceBrowser navigation items + hbox = super(VolSourceEstimateViewer, self)._configure_status_bar( + hbox=hbox) + return hbox + + def _configure_data_plot(self): + """Configure the plot that shows spectrograms/time-courses.""" + from ._core import _make_mpl_plot + canvas, self._fig = _make_mpl_plot( + dpi=96, tight=False, hide_axes=False, invert=False, + facecolor='white') + self._fig.axes[0].set_position([0.12, 0.25, 0.73, 0.7]) + self._fig.axes[0].set_xlabel('Time (s)') + min_idx = np.argmin(abs(self._inst.times)) + self._fig.axes[0].set_xticks( + [0, min_idx, self._inst.times.size - 1]) + self._fig.axes[0].set_xticklabels( + self._inst.times[[0, min_idx, -1]].round(2)) + stc_data = self._pick_stc_image() + if self._f_idx is None: + self._fig.axes[0].set_facecolor('black') + self._stc_plot = self._fig.axes[0].plot( + stc_data[0], color='white')[0] + self._stc_vline = self._fig.axes[0].axvline( + x=self._t_idx, color='lime') + self._fig.axes[0].set_ylabel('Activation (AU)') + self._cax = None + else: + self._stc_plot = self._fig.axes[0].imshow( + stc_data, aspect='auto', cmap=self._cmap, + interpolation='bicubic') + self._stc_vline = self._fig.axes[0].axvline( + x=self._t_idx, color='lime', linewidth=0.5) + self._stc_hline = self._fig.axes[0].axhline( + y=self._f_idx, color='lime', linewidth=0.5) + self._fig.axes[0].invert_yaxis() + self._fig.axes[0].set_ylabel('Frequency (Hz)') + self._fig.axes[0].set_yticks(range(self._inst.freqs.size)) + self._fig.axes[0].set_yticklabels(self._inst.freqs.round(2)) + self._cax = self._fig.add_axes([0.88, 0.25, 0.02, 0.6]) + self._cbar = self._fig.colorbar(self._stc_plot, cax=self._cax) + self._cax.set_ylabel('Power') + self._fig.canvas.mpl_connect( + 'button_release_event', self._on_data_plot_click) + canvas.setMinimumHeight(int(self.size().height() * 0.4)) + canvas.keyPressEvent = self.keyPressEvent + return canvas + + def _plot_topomap(self): + self._topo_fig.axes[0].clear() + self._topo_cax.clear() + dtype = self._data_type_selector.currentText() + units = DEFAULTS['units'][dtype] + scaling = DEFAULTS['scalings'][dtype] + + if isinstance(self._inst, EpochsTFR): + inst_data = self._inst.data + scaling *= scaling # power is squared + units = f'({units})' + r'$^2$/Hz' + elif isinstance(self._inst, BaseEpochs): + inst_data = self._inst.get_data() + else: + inst_data = self._inst.data[None] # new axis for single epoch + + if self._epoch_idx == 'ITC': + units = 'ITC' + scaling = 1 + + pick_idx = _picks_to_idx(self._inst.info, dtype) + inst_data = inst_data[:, pick_idx] + + evo_data = self._pick_epoch(inst_data) * scaling + + if self._f_idx is not None: + evo_data = evo_data[:, self._f_idx] + + if self._baseline != 'none': + units = units if self._baseline == 'mean' else '' + evo_data = rescale( + evo_data.astype(float), times=self._inst.times, + baseline=(float(self._bl_tmin), float(self._bl_tmax)), + mode=self._baseline, copy=False) + + info = _pick_inst(self._inst, dtype, 'bads').info + ave = EvokedArray(evo_data, info, tmin=self._inst.times[0]) + + ave_max = evo_data.max() + self._ave_min = min([evo_data.min(), -ave_max]) + self._ave_range = max([ave_max, -self._ave_min]) - self._ave_min + vmin, vmax = [val / SLIDER_WIDTH * self._ave_range + self._ave_min + for val in (self._cmap_sliders[i].value() + for i in (0, 2))] + cbar_fmt = '%3.1f' if abs(evo_data).max() < 1e3 else '%.1e' + ave.plot_topomap(times=self._inst.times[self._t_idx], + scalings={dtype: 1}, units=units, + axes=(self._topo_fig.axes[0], self._topo_cax), + cmap=self._cmap, colorbar=True, cbar_fmt=cbar_fmt, + vlim=(vmin, vmax), show=False) + + self._topo_fig.axes[0].set_title('') + self._topo_fig.subplots_adjust(top=1.1, bottom=0.05, right=0.75) + self._topo_fig.canvas.draw() + + def _configure_topo_plot(self): + """Configure the plot that shows topomap.""" + from ._core import _make_mpl_plot + canvas, self._topo_fig = _make_mpl_plot( + dpi=96, hide_axes=False, facecolor='white') + self._topo_cax = self._topo_fig.add_axes([0.77, 0.1, 0.02, 0.75]) + self._plot_topomap() + canvas.setMinimumHeight(int(self.size().height() * 0.4)) + canvas.setMaximumWidth(int(self.size().width() * 0.4)) + canvas.keyPressEvent = self.keyPressEvent + return canvas + + def keyPressEvent(self, event): + """Execute functions when the user presses a key.""" + super().keyPressEvent(event) + + # update if textbox done editing + if event.key() == QtCore.Qt.Key_Return: + for widget in (self._bl_tmin_textbox, self._bl_tmax_textbox): + if widget.hasFocus(): + widget.clearFocus() + self.setFocus() # removing focus calls focus out event + + def _on_data_plot_click(self, event): + """Update viewer when the data plot is clicked on.""" + if event.inaxes is self._fig.axes[0]: + if self._f_idx is not None: + self._update = False + self.set_freq(self._inst.freqs[int(round(event.ydata))]) + self._update = True + self.set_time(self._inst.times[int(round(event.xdata))]) + self._update_intensity() + + def set_baseline(self, baseline=None, mode=None): + """Set the baseline. + + Parameters + ---------- + baseline : array-like, shape (2,) | None + The time interval to apply rescaling / baseline correction. + If None do not apply it. If baseline is (a, b) + the interval is between "a (s)" and "b (s)". + If a is None the beginning of the data is used + and if b is None then b is set to the end of the interval. + If baseline is equal to (None, None) all the time + interval is used. + mode : 'mean' | 'ratio' | 'logratio' | 'percent' | 'zscore' | 'zlogratio' + Perform baseline correction by + + - subtracting the mean of baseline values ('mean') + - dividing by the mean of baseline values ('ratio') + - dividing by the mean of baseline values and taking the log + ('logratio') + - subtracting the mean of baseline values followed by dividing by + the mean of baseline values ('percent') + - subtracting the mean of baseline values and dividing by the + standard deviation of baseline values ('zscore') + - dividing by the mean of baseline values, taking the log, and + dividing by the standard deviation of log baseline values + ('zlogratio') + tmin : float + The minimum baseline time + """ # noqa E501 + _check_option('mode', mode, ('mean', 'ratio', 'logratio', 'percent', + 'zscore', 'zlogratio', 'none', None)) + self._update = False + self._baseline_selector.setCurrentText( + 'none' if mode is None else mode) + if baseline is not None: + baseline = _check_baseline(baseline, times=self._inst.times, + sfreq=self._inst.info['sfreq']) + tmin, tmax = baseline + self._bl_tmin_textbox.setText(str(tmin)) + self._bl_tmax_textbox.setText(str(tmax)) + self._update = True + self._update_stc_all() + + def _update_baseline(self, name): + """Update the chosen baseline normalization method.""" + self._baseline = name + pre_update = self._update + self._update = False + self._cmap_sliders[0].setValue(0) + self._cmap_sliders[1].setValue(SLIDER_WIDTH // 2) + self._update = pre_update + self._cmap_sliders[2].setValue(SLIDER_WIDTH) + # all baselines have negative support + self._cmap = _get_cmap('hot' if name == 'none' and self._pos_support + else 'mne') + self._go_to_extreme_button.setText( + 'Go to Max' if name == 'none' and self._pos_support else + 'Go to Extreme') + if self._update: # don't update if bl_tmin, bl_tmax are also changing + self._update_stc_all() + + def _update_baseline_tmin(self, event): + """Update tmin for the baseline.""" + try: + tmin = float(self._bl_tmin_textbox.text()) + except ValueError: + self._bl_tmin_textbox.setText(str(round(self._bl_tmin, 2))) + tmin = self._inst.times[np.clip( # find nearest time + self._inst.time_as_index(tmin, use_rounding=True)[0], + 0, self._inst.times.size - 1)] + if tmin == self._bl_tmin: + return + self._bl_tmin = tmin + if self._update: + self._update_stc_all() + + def _update_baseline_tmax(self, event): + """Update tmax for the baseline.""" + try: + tmax = float(self._bl_tmax_textbox.text()) + except ValueError: + self._bl_tmax_textbox.setText(str(round(self._bl_tmax, 2))) + return + tmax = self._inst.times[np.clip( # find nearest time + self._inst.time_as_index(tmax, use_rounding=True)[0], + 0, self._inst.times.size - 1)] + if tmax == self._bl_tmax: + return + self._bl_tmax = tmax + if self._update: + self._update_stc_all() + + def _update_data_type(self, dtype): + """Update which data type is shown in the topomap.""" + self._plot_topomap() + + def _update_data_plot_ylabel(self): + """Update the ylabel of the data plot.""" + if self._epoch_idx == 'ITC': + self._cax.set_ylabel('ITC') + elif self._is_complex: + self._cax.set_ylabel('Power') + else: + self._fig.axes[0].set_ylabel('Activation (AU)') + + def _update_epoch(self, name): + """Change which epoch is viewed.""" + self._epoch_idx = name + # handle plot labels + self._update_data_plot_ylabel() + # reset sliders + if name == 'ITC' and self._epoch_idx != 'ITC': + self._cmap_sliders[0].setValue(0) + self._cmap_sliders[1].setValue(SLIDER_WIDTH // 2) + self._cmap_sliders[2].setValue(SLIDER_WIDTH) + self._baseline_selector.setCurrentText('none') + + if self._update: + self._update_stc_all() + self._update_vectors() + + def set_freq(self, freq): + """Set the frequency to display (in Hz). + + Parameters + ---------- + freq : float + The frequency to show, in Hz. + """ + if self._f_idx is None: + raise ValueError('Source estimate does not contain frequencies') + self._freq_slider.setValue(np.argmin(abs(self._inst.freqs - freq))) + + def _update_freq(self, event=None): + """Update freq slider values.""" + self._f_idx = self._freq_slider.value() + self._freq_label.setText( + f'Freq = {self._inst.freqs[self._f_idx].round(2)} Hz') + if self._update: + self._update_stc_images() # just need volume updated here + self._stc_hline.set_ydata([self._f_idx]) + self._update_intensity() + if self._show_topomap and self._update: + self._plot_topomap() + self._fig.canvas.draw() + + def set_time(self, time): + """Set the time to display (in seconds). + + Parameters + ---------- + time : float + The time to show, in seconds. + """ + self._time_slider.setValue(np.clip( + self._inst.time_as_index(time, use_rounding=True)[0], + 0, self._inst.times.size - 1)) + + def _update_time(self, event=None): + """Update time slider values.""" + self._t_idx = self._time_slider.value() + self._time_label.setText( + f'Time = {self._inst.times[self._t_idx].round(2)} s') + if self._update: + self._update_stc_images() # just need volume updated here + self._stc_vline.set_xdata([self._t_idx]) + self._update_intensity() + if self._show_topomap and self._update: + self._plot_topomap() + self._update_vectors() + self._fig.canvas.draw() + + def set_alpha(self, alpha): + """Set the opacity of the display. + + Parameters + ---------- + alpha : float + The opacity to use. + """ + self._alpha_slider.setValue(np.clip(alpha, 0, 1)) + + def _update_alpha(self, event=None): + """Update stc plot alpha.""" + self._alpha = round(self._alpha_slider.value() / SLIDER_WIDTH, 2) + self._alpha_label.setText(f'Alpha = {self._alpha}') + for axis in range(3): + self._images['stc'][axis].set_alpha(self._alpha) + self._update_cmap() + + def set_cmap(self, vmin=None, vmid=None, vmax=None): + """Update the colormap. + + Parameters + ---------- + vmin : float + The minimum color value relative to the selected data in [0, 1]. + vmin : float + The middle color value relative to the selected data in [0, 1]. + vmin : float + The maximum color value relative to the selected data in [0, 1]. + """ + for val, name in zip((vmin, vmid, vmax), ('vmin', 'vmid', 'vmax')): + _validate_type(val, (int, float, None)) + + self._update = False + for i, val in enumerate((vmin, vmid, vmax)): + if val is not None: + _check_range(val, 0, 1, name) + self._cmap_sliders[i].setValue(int(round(val * SLIDER_WIDTH))) + self._update = True + self._update_cmap() + + def _update_cmap(self, event=None, draw=True, update_slice_plots=True, + update_3d=True): + """Update the colormap.""" + if not self._update: + return + + # no recursive updating + update_tmp = self._update + self._update = False + if self._cmap_sliders[0].value() > self._cmap_sliders[2].value(): + tmp = self._cmap_sliders[0].value() + self._cmap_sliders[0].setValue(self._cmap_sliders[2].value()) + self._cmap_sliders[2].setValue(tmp) + if self._cmap_sliders[1].value() > self._cmap_sliders[2].value(): + self._cmap_sliders[1].setValue(self._cmap_sliders[2].value()) + if self._cmap_sliders[1].value() < self._cmap_sliders[0].value(): + self._cmap_sliders[1].setValue(self._cmap_sliders[0].value()) + self._update = update_tmp + + vmin, vmid, vmax = [ + val / SLIDER_WIDTH * self._stc_range + self._stc_min + for val in (self._cmap_sliders[i].value() for i in range(3))] + mid_pt = (vmid - vmin) / (vmax - vmin) + ctable = self._cmap(np.concatenate([ + np.linspace(0, mid_pt, 128), np.linspace(mid_pt, 1, 128)])) + cmap = LinearSegmentedColormap.from_list('stc', ctable.tolist(), N=256) + ctable = np.round(ctable * 255.0).astype(np.uint8) + if self._stc_min < 0: # make center values transparent + zero_pt = np.argmin(abs(np.linspace(vmin, vmax, 256))) + # 31 on either side of the zero point are made transparent + ctable[max([zero_pt - 31, 0]):min([zero_pt + 32, 255]), 3] = 0 + else: # make low values transparent + ctable[:25, 3] = np.linspace(0, 255, 25) + + for axis in range(3): + self._images['stc'][axis].set_clim(vmin, vmax) + self._images['stc'][axis].set_cmap(cmap) + if draw and self._update: + self._figs[axis].canvas.draw() + + # update nans in slice plot image + if update_slice_plots and self._update: + self._update_stc_volume() + self._plot_stc_images(draw=draw) + + if self._f_idx is None: + self._fig.axes[0].set_ylim( + [self._stc_min, self._stc_min + self._stc_range]) + else: + self._stc_plot.set_clim(vmin, vmax) + self._stc_plot.set_cmap(cmap) + # update colorbar + self._cax.clear() + self._cbar = self._fig.colorbar(self._stc_plot, cax=self._cax) + self._update_data_plot_ylabel() + + if self._show_topomap: + topo_vmin, topo_vmax = [ + val / SLIDER_WIDTH * self._ave_range + self._ave_min + for val in (self._cmap_sliders[i].value() for i in (0, 2))] + self._topo_fig.axes[0].get_images()[0].set_clim( + topo_vmin, topo_vmax) + if draw and self._update: + self._topo_fig.canvas.draw() + + if draw and self._update: + self._fig.canvas.draw() + + if not update_3d: + return + + if self._data.shape[2] > 1 and not self._is_complex: + # update vector mask + self._update_vector_threshold() + self._plot_vectors(draw=False) + self._renderer._set_colormap_range( + actor=self._vector_actor, ctable=ctable, scalar_bar=None, + rng=[0, VECTOR_SCALAR]) + + # set alpha + ctable[ctable[:, 3] > self._alpha * 255, 3] = self._alpha * 255 + self._renderer._set_volume_range(self._volume_pos, ctable, self._alpha, + self._scalar_bar, [vmin, vmax]) + self._renderer._set_volume_range(self._volume_neg, ctable, self._alpha, + self._scalar_bar, [vmin, vmax]) + if draw and self._update: + self._renderer._update() + + def go_to_extreme(self): + """Go to the extreme intensity source vertex.""" + stc_idx, f_idx, t_idx = np.unravel_index(np.nanargmax( + abs(self._stc_data_vol)), self._stc_data_vol.shape) + if self._f_idx is not None: + self._freq_slider.setValue(f_idx) + self._time_slider.setValue(t_idx) + max_coord = np.array(np.where(self._src_lut == stc_idx)).flatten() + max_coord_mri = _coord_to_coord( + max_coord, self._src_vox_scan_ras_t, self._scan_ras_vox_t) + self._set_ras(apply_trans(self._vox_ras_t, max_coord_mri)) + + def _plot_data(self, draw=True): + """Update which coordinate's data is being shown.""" + stc_data = self._pick_stc_image() + if self._f_idx is None: # no freq data + self._stc_plot.set_ydata(stc_data[0]) + else: + self._stc_plot.set_data(stc_data) + if draw and self._update: + self._fig.canvas.draw() + + def _toggle_interp(self): + """Toggle interpolating the spectrogram data plot.""" + if self._interp_button.text() == 'Off': + self._interp_button.setText('On') + self._interp_button.setStyleSheet("background-color: green") + else: # text == 'On', turn off + self._interp_button.setText('Off') + self._interp_button.setStyleSheet("background-color: red") + + self._stc_plot.set_interpolation( + 'bicubic' if self._interp_button.text() == 'On' else None) + if self._update: + self._fig.canvas.draw() + # draws data plot, fixes vmin, vmax + self._update_cmap(update_slice_plots=False, update_3d=False) + + def _update_intensity(self): + """Update the intensity label.""" + label_str = '{:.3f}' + if self._stc_range > 1e5: + label_str = '{:.3e}' + elif np.issubdtype(self._stc_img.dtype, np.integer): + label_str = '{:d}' + self._intensity_label.setText( + ('intensity = ' + label_str).format( + self._stc_img[tuple(self._get_src_coord())])) + + def _update_moved(self): + """Update when cursor position changes.""" + super()._update_moved() + self._update_intensity() + + @fill_doc + def set_3d_view(self, roll=None, distance=None, azimuth=None, + elevation=None, focalpoint=None): + """Orient camera to display view. + + Parameters + ---------- + %(roll)s + %(distance)s + %(azimuth)s + %(elevation)s + %(focalpoint)s + """ + self._renderer.set_camera( + roll=roll, distance=distance, azimuth=azimuth, + elevation=elevation, focalpoint=focalpoint, reset_camera=False) + self._renderer._update() + + def _plot_vectors(self, draw=True): + """Update the vector plots.""" + if self._data.shape[2] > 1 and not self._is_complex: + self._vector_data.point_data['vec'] = \ + VECTOR_SCALAR * self._stc_vectors_masked + if draw and self._update: + self._renderer._update() + + def _update_stc_images(self, draw=True): + """Update the stc image based on the time and frequency range.""" + self._update_stc_volume() + self._plot_stc_images(draw=draw) + self._plot_3d_stc(draw=draw) + + def _plot_3d_stc(self, draw=True): + """Update the 3D rendering.""" + self._plot_vectors(draw=False) + self._grid.cell_data['values'] = np.where( + np.isnan(self._stc_img), 0., self._stc_img).flatten(order='F') + if draw and self._update: + self._renderer._update() + + def _plot_stc_images(self, axis=None, draw=True): + """Update the stc image(s).""" + src_coord = self._get_src_coord() + for axis in range(3): + # ensure in bounds + if src_coord[axis] >= 0 and \ + src_coord[axis] < self._stc_img.shape[axis]: + stc_slice = np.take( + self._stc_img, src_coord[axis], axis=axis).T + else: + stc_slice = np.take(self._stc_img, 0, axis=axis).T * np.nan + self._images['stc'][axis].set_data(stc_slice) + if draw and self._update: + self._draw(axis) + + def _update_images(self, axis=None, draw=True): + """Update images when general changes happen.""" + self._plot_stc_images(axis=axis, draw=draw) + self._plot_data(draw=draw) + super()._update_images() diff --git a/mne/gui/tests/test_ieeg_locate.py b/mne/gui/tests/test_ieeg_locate.py index 7fb2c544066..4f001226f26 100644 --- a/mne/gui/tests/test_ieeg_locate.py +++ b/mne/gui/tests/test_ieeg_locate.py @@ -82,6 +82,7 @@ def test_ieeg_elec_locate_io(renderer_interactive_pyvistaqt): mne.gui.locate_ieeg(info, trans, aligned_ct, subject, subjects_dir) +@pytest.mark.allow_unclosed_pyside2 @requires_version('sphinx_gallery') @testing.requires_testing_data def test_locate_scraper(renderer_interactive_pyvistaqt, _fake_CT_coords, @@ -115,6 +116,7 @@ def test_locate_scraper(renderer_interactive_pyvistaqt, _fake_CT_coords, # no need to call .close +@pytest.mark.allow_unclosed_pyside2 @testing.requires_testing_data def test_ieeg_elec_locate_display(renderer_interactive_pyvistaqt, _fake_CT_coords): diff --git a/mne/gui/tests/test_vol_stc.py b/mne/gui/tests/test_vol_stc.py new file mode 100644 index 00000000000..c2552507e43 --- /dev/null +++ b/mne/gui/tests/test_vol_stc.py @@ -0,0 +1,208 @@ +# -*- coding: utf-8 -*- +# Authors: Alex Rockhill +# +# License: BSD-3-clause + +import sys +import numpy as np +from numpy.testing import assert_allclose + +import pytest + +import mne +from mne.datasets import testing +from mne.io.constants import FIFF +from mne.viz.utils import _fake_click + +data_path = testing.data_path(download=False) +subject = 'sample' +subjects_dir = data_path / 'subjects' +fname_raw = data_path / 'MEG' / 'sample' / 'sample_audvis_trunc_raw.fif' +fname_fwd_vol = data_path / 'MEG' / 'sample' / \ + 'sample_audvis_trunc-meg-vol-7-fwd.fif' +fname_fwd = \ + data_path / 'MEG' / 'sample' / 'sample_audvis_trunc-meg-eeg-oct-4-fwd.fif' + +# TO DO: remove when Azure fixed, causes +# 'Windows fatal exception: access violation' +# but fails to replicate locally on a Windows machine +if sys.platform == 'win32': + pytest.skip('Azure CI problem on Windows', allow_module_level=True) + + +def _fake_stc(src_type='vol'): + """Fake a 5D source time estimate.""" + rng = np.random.default_rng(11) + n_epochs = 3 + info = mne.io.read_info(fname_raw) + info = mne.pick_info(info, mne.pick_types(info, meg='grad')) + if src_type == 'vol': + src = mne.setup_volume_source_space( + subject='sample', subjects_dir=subjects_dir, + mri='aseg.mgz', volume_label='Left-Cerebellum-Cortex', + pos=20, add_interpolator=False) + else: + assert src_type == 'surf' + forward = mne.read_forward_solution(fname_fwd) + src = forward['src'] + for this_src in src: + this_src['coord_frame'] = FIFF.FIFFV_COORD_MRI + this_src['subject_his_id'] = 'sample' + freqs = np.arange(8, 10) + times = np.arange(0.1, 0.11, 1 / info['sfreq']) + data = rng.integers(-1000, 1000, size=(n_epochs, len(info.ch_names), + freqs.size, times.size)) + \ + 1j * rng.integers(-1000, 1000, size=(n_epochs, len(info.ch_names), + freqs.size, times.size)) + epochs_tfr = mne.time_frequency.EpochsTFR( + info, data, times=times, freqs=freqs) + nuse = sum([this_src['nuse'] for this_src in src]) + stc_data = rng.integers(-1000, 1000, size=(n_epochs, nuse, 3, + freqs.size, times.size)) + \ + 1j * rng.integers(-1000, 1000, size=(n_epochs, nuse, 3, + freqs.size, times.size)) + return stc_data, src, epochs_tfr + + +@pytest.mark.allow_unclosed_pyside2 +def test_stc_viewer_io(renderer_interactive_pyvistaqt): + """Test the input/output of the stc viewer GUI.""" + pytest.importorskip('nibabel') + pytest.importorskip('dipy') + from mne.gui._vol_stc import VolSourceEstimateViewer + + stc_data, src, epochs_tfr = _fake_stc() + with pytest.raises(NotImplementedError, + match='surface source estimate ' + 'viewing is not yet supported'): + VolSourceEstimateViewer(stc_data, inst=epochs_tfr) + with pytest.raises(NotImplementedError, match='source estimate object'): + VolSourceEstimateViewer(stc_data, src=src) + with pytest.raises(ValueError, match='`data` must be an array'): + VolSourceEstimateViewer('foo', subject='sample', + subjects_dir=subjects_dir, + src=src, inst=epochs_tfr) + with pytest.raises(ValueError, + match='Number of epochs in `inst` does not match'): + VolSourceEstimateViewer(stc_data[1:], src=src, inst=epochs_tfr) + with pytest.raises(RuntimeError, + match='ource vertices in `data` do not match '): + VolSourceEstimateViewer(stc_data[:, :1], subject='sample', + subjects_dir=subjects_dir, + src=src, inst=epochs_tfr) + src[0]['coord_frame'] = FIFF.FIFFV_COORD_HEAD + with pytest.raises(RuntimeError, match='must be in the `mri`'): + VolSourceEstimateViewer(stc_data, subject='sample', + subjects_dir=subjects_dir, + src=src, inst=epochs_tfr) + src[0]['coord_frame'] = FIFF.FIFFV_COORD_MRI + + src[0]['subject_his_id'] = 'foo' + with pytest.raises(RuntimeError, match='Source space subject'): + with pytest.warns(RuntimeWarning, match='`pial` surface not found'): + VolSourceEstimateViewer(stc_data, subject='sample', + subjects_dir=subjects_dir, + src=src, inst=epochs_tfr) + + with pytest.raises(ValueError, + match='Frequencies in `inst` do not match'): + VolSourceEstimateViewer( + stc_data[:, :, :, 1:], src=src, inst=epochs_tfr) + + with pytest.raises(ValueError, match='Complex data is required'): + VolSourceEstimateViewer(stc_data.real, src=src, inst=epochs_tfr) + + with pytest.raises(ValueError, + match='Times in `inst` do not match'): + VolSourceEstimateViewer( + stc_data[:, :, :, :, 1:], src=src, inst=epochs_tfr) + + +@pytest.mark.allow_unclosed_pyside2 +@testing.requires_testing_data +def test_stc_viewer_display(renderer_interactive_pyvistaqt): + """Test that the stc viewer GUI displays properly.""" + pytest.importorskip('nibabel') + pytest.importorskip('dipy') + from mne.gui._vol_stc import VolSourceEstimateViewer + + stc_data, src, epochs_tfr = _fake_stc() + with pytest.warns(RuntimeWarning, match='`pial` surface not found'): + viewer = VolSourceEstimateViewer( + stc_data, subject='sample', subjects_dir=subjects_dir, + src=src, inst=epochs_tfr) + # test go to max + viewer._go_to_extreme_button.click() + assert_allclose(viewer._ras, [-20, -60, -20], atol=0.01) + + src_coord = viewer._get_src_coord() + stc_idx = viewer._src_lut[src_coord] + + viewer._epoch_selector.setCurrentText('Epoch 0') + assert viewer._epoch_idx == 'Epoch 0' + + viewer._freq_slider.setValue(1) + assert viewer._f_idx == 1 + + viewer._time_slider.setValue(2) + assert viewer._t_idx == 2 + + plot_data = np.linalg.norm((stc_data[0] * stc_data[0].conj()).real, + axis=1)[stc_idx] + assert_allclose(plot_data, viewer._stc_plot.get_array()) + + # test clicking on stc plot + _fake_click(viewer._fig, viewer._fig.axes[0], + (0, 0), xform='data', kind='release') + assert viewer._t_idx == 0 + assert viewer._f_idx == 0 + + # test baseline + for mode in ('zscore', 'ratio'): + viewer.set_baseline((0.1, None), mode) + + # done with time-frequency, close + viewer.close() + + # test time only, not frequencies + epochs = mne.EpochsArray(epochs_tfr.data[:, :, 0].real, epochs_tfr.info, + tmin=epochs_tfr.tmin) + stc_time_data = stc_data[:, :, :, 0:1].real + with pytest.warns(RuntimeWarning, match='`pial` surface not found'): + viewer = VolSourceEstimateViewer( + stc_time_data, subject='sample', + subjects_dir=subjects_dir, src=src, inst=epochs) + + # test go to max + viewer._go_to_extreme_button.click() + assert_allclose(viewer._ras, [-20, -60, -20], atol=0.01) + + src_coord = viewer._get_src_coord() + stc_idx = viewer._src_lut[src_coord] + + viewer._epoch_selector.setCurrentText('Epoch 0') + assert viewer._epoch_idx == 'Epoch 0' + + with pytest.raises(ValueError, match='Source estimate does ' + 'not contain frequencies'): + viewer.set_freq(10) + + viewer._time_slider.setValue(2) + assert viewer._t_idx == 2 + + assert_allclose(np.linalg.norm(stc_time_data[0], axis=1)[stc_idx][0], + viewer._stc_plot.get_data()[1]) + viewer.close() + + +@testing.requires_testing_data +def test_stc_viewer_surface(renderer_interactive_pyvistaqt): + """Test the stc viewer with a surface source space.""" + pytest.importorskip('nibabel') + pytest.importorskip('dipy') + from mne.gui._vol_stc import VolSourceEstimateViewer + stc_data, src, epochs_tfr = _fake_stc(src_type='surf') + with pytest.raises(RuntimeError, match='not implemented yet'): + VolSourceEstimateViewer( + stc_data, subject='sample', + subjects_dir=subjects_dir, src=src, inst=epochs_tfr) diff --git a/mne/utils/__init__.py b/mne/utils/__init__.py index 0d04c882783..2ceef298cae 100644 --- a/mne/utils/__init__.py +++ b/mne/utils/__init__.py @@ -11,7 +11,7 @@ _check_pandas_index_arguments, _check_event_id, _check_ch_locs, _check_compensation_grade, _check_if_nan, _is_numeric, _ensure_int, _check_preload, - _validate_type, _check_info_inv, + _validate_type, _check_range, _check_info_inv, _check_channels_spatial_filter, _check_one_ch_type, _check_rank, _check_option, _check_depth, _check_combine, _path_like, _check_src_normal, _check_stc_units, diff --git a/mne/utils/check.py b/mne/utils/check.py index cb4459e9e26..780458264c9 100644 --- a/mne/utils/check.py +++ b/mne/utils/check.py @@ -551,6 +551,37 @@ def _validate_type(item, types=None, item_name=None, type_name=None, *, f"got {type(item)} instead.") +def _check_range(val, min_val, max_val, name, min_inclusive=True, + max_inclusive=True): + """Check that item is within range. + + Parameters + ---------- + val : int | float + The value to be checked. + min_val : int | float + The minimum value allowed. + max_val : int | float + The maximum value allowed. + name : str + The name of the value. + min_inclusive : bool + Whether ``val`` is allowed to be ``min_val``. + max_inclusive : bool + Whether ``val`` is allowed to be ``max_val``. + """ + below_min = val < min_val if min_inclusive else val <= min_val + above_max = val > max_val if max_inclusive else val >= max_val + if below_min or above_max: + error_str = f'The value of {name} must be between {min_val} ' + if min_inclusive: + error_str += 'inclusive ' + error_str += f'and {max_val}' + if max_inclusive: + error_str += 'inclusive ' + raise ValueError(error_str) + + def _path_like(item): """Validate that `item` is `path-like`. diff --git a/mne/utils/tests/test_check.py b/mne/utils/tests/test_check.py index 8f28ee7799a..44caa61ba10 100644 --- a/mne/utils/tests/test_check.py +++ b/mne/utils/tests/test_check.py @@ -18,7 +18,8 @@ from mne.utils import (check_random_state, _check_fname, check_fname, _suggest, _check_subject, _check_info_inv, _check_option, Bunch, check_version, _path_like, _validate_type, _on_missing, - _safe_input, _check_ch_locs, _check_sphere) + _safe_input, _check_ch_locs, _check_sphere, + _check_range) data_path = testing.data_path(download=False) base_dir = data_path / "MEG" / "sample" @@ -184,6 +185,15 @@ def test_validate_type(): _validate_type(False, 'int-like') +def test_check_range(): + """Test _check_range.""" + _check_range(10, 1, 100, 'value') + with pytest.raises(ValueError, match='must be between'): + _check_range(0, 1, 10, 'value') + with pytest.raises(ValueError, match='must be between'): + _check_range(1, 1, 10, 'value', False, False) + + @testing.requires_testing_data def test_suggest(): """Test suggestions.""" diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index be10dbe9502..0db592fe14c 100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -908,7 +908,6 @@ def _plot_topomap( border=_BORDER_DEFAULT, res=64, cmap=None, vmin=None, vmax=None, cnorm=None, show=True, onselect=None): from matplotlib.colors import Normalize - import matplotlib.pyplot as plt from matplotlib.widgets import RectangleSelector data = np.asarray(data) logger.debug(f'Plotting topomap for {ch_type} data shape {data.shape}') @@ -1050,7 +1049,7 @@ def _plot_topomap( verticalalignment='center', size='x-small') if not axes.figure.get_constrained_layout(): - plt.subplots_adjust(top=.95) + axes.figure.subplots_adjust(top=.95) if onselect is not None: lim = axes.dataLim