From 7aa3a27631b16fc9c16852c720e7243d876fe104 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Sat, 27 Aug 2022 00:44:52 +0100 Subject: [PATCH 1/7] ENH: Add temperature and galvanic (#11090) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ENH: Add temperature and galvanic * FIX: Use correct name * FIX: Flake * FIX: Rename * DOC: Sp * Update doc/_includes/channel_types.rst Co-authored-by: Richard Höchenberger Co-authored-by: Richard Höchenberger --- doc/_includes/channel_types.rst | 4 ++++ doc/changes/latest.inc | 1 + mne/channels/channels.py | 12 +++++++----- mne/defaults.py | 16 ++++++++++------ mne/io/constants.py | 8 ++++++-- mne/io/pick.py | 29 ++++++++++++++++++++--------- mne/io/tests/test_constants.py | 2 +- mne/utils/docs.py | 4 ++++ mne/viz/raw.py | 2 +- 9 files changed, 54 insertions(+), 24 deletions(-) diff --git a/doc/_includes/channel_types.rst b/doc/_includes/channel_types.rst index 6c0adc18c3d..647dab25ba4 100644 --- a/doc/_includes/channel_types.rst +++ b/doc/_includes/channel_types.rst @@ -65,4 +65,8 @@ ias Internal Active Shielding data syst System status channel information (Triux systems only) + +temperature Temperature Degrees Celsius + +gsr Galvanic skin response Siemens ============= ========================================= ================= diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 05b238a21b5..884aadf9cfa 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -33,6 +33,7 @@ Enhancements - Add ``starting_affine`` keyword argument to :func:`mne.transforms.compute_volume_registration` to initialize an alignment with an affine (:gh:`11020` by `Alex Rockhill`_) - The ``trans`` parameter in :func:`mne.make_field_map` now accepts a :class:`~pathlib.Path` object, and uses standardised loading logic (:gh:`10784` by :newcontrib:`Andrew Quinn`) - Add HTML representation for `~mne.Evoked` in Jupyter Notebooks (:gh:`11075` by `Valerii Chirkov`_ and `Andrew Quinn`_) +- Add support for ``temperature`` and ``gsr`` (galvanic skin response, i.e., electrodermal activity) channel types (:gh:`11090` by `Eric Larson`_) - Allow :func:`mne.beamformer.make_dics` to take ``pick_ori='vector'`` to compute vector source estimates (:gh:`19080` by `Alex Rockhill`_) - Add ``on_missing`` functionality to all of our classes that have a ``drop_channels`` method, to control what happens when channel names are not in the object (:gh:`11077` by `Andrew Quinn`_) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index d1261ee4b9e..f2f22fadc7e 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -328,7 +328,7 @@ def set_channel_types(self, mapping, verbose=None): ecg, eeg, emg, eog, exci, ias, misc, resp, seeg, dbs, stim, syst, ecog, hbo, hbr, fnirs_cw_amplitude, fnirs_fd_ac_amplitude, - fnirs_fd_phase, fnirs_od + fnirs_fd_phase, fnirs_od, temperature, gsr .. versionadded:: 0.9.0 """ @@ -590,11 +590,12 @@ class UpdateChannelsMixin(object): @verbose def pick_types(self, meg=False, eeg=False, stim=False, eog=False, - ecg=False, emg=False, ref_meg='auto', misc=False, + ecg=False, emg=False, ref_meg='auto', *, misc=False, resp=False, chpi=False, exci=False, ias=False, syst=False, seeg=False, dipole=False, gof=False, bio=False, - ecog=False, fnirs=False, csd=False, dbs=False, include=(), - exclude='bads', selection=None, verbose=None): + ecog=False, fnirs=False, csd=False, dbs=False, + temperature=False, gsr=False, + include=(), exclude='bads', selection=None, verbose=None): """Pick some channels by type and names. Parameters @@ -620,7 +621,8 @@ def pick_types(self, meg=False, eeg=False, stim=False, eog=False, ref_meg=ref_meg, misc=misc, resp=resp, chpi=chpi, exci=exci, ias=ias, syst=syst, seeg=seeg, dipole=dipole, gof=gof, bio=bio, ecog=ecog, fnirs=fnirs, csd=csd, dbs=dbs, include=include, - exclude=exclude, selection=selection) + exclude=exclude, selection=selection, temperature=temperature, + gsr=gsr) self._pick_drop_channels(idx) diff --git a/mne/defaults.py b/mne/defaults.py index 27bfe96b26c..0340170f994 100644 --- a/mne/defaults.py +++ b/mne/defaults.py @@ -12,25 +12,27 @@ exci='k', ias='k', syst='k', seeg='saddlebrown', dbs='seagreen', dipole='k', gof='k', bio='k', ecog='k', hbo='#AA3377', hbr='b', fnirs_cw_amplitude='k', fnirs_fd_ac_amplitude='k', - fnirs_fd_phase='k', fnirs_od='k', csd='k', whitened='k'), + fnirs_fd_phase='k', fnirs_od='k', csd='k', whitened='k', + gsr='#666633', temperature='#663333'), si_units=dict(mag='T', grad='T/m', eeg='V', eog='V', ecg='V', emg='V', misc='AU', seeg='V', dbs='V', dipole='Am', gof='GOF', bio='V', ecog='V', hbo='M', hbr='M', ref_meg='T', fnirs_cw_amplitude='V', fnirs_fd_ac_amplitude='V', fnirs_fd_phase='rad', fnirs_od='V', csd='V/m²', - whitened='Z'), + whitened='Z', gsr='S', temperature='C'), units=dict(mag='fT', grad='fT/cm', eeg='µV', eog='µV', ecg='µV', emg='µV', misc='AU', seeg='mV', dbs='µV', dipole='nAm', gof='GOF', bio='µV', ecog='µV', hbo='µM', hbr='µM', ref_meg='fT', fnirs_cw_amplitude='V', fnirs_fd_ac_amplitude='V', fnirs_fd_phase='rad', fnirs_od='V', csd='mV/m²', - whitened='Z'), + whitened='Z', gsr='S', temperature='C'), # scalings for the units scalings=dict(mag=1e15, grad=1e13, eeg=1e6, eog=1e6, emg=1e6, ecg=1e6, misc=1.0, seeg=1e3, dbs=1e6, ecog=1e6, dipole=1e9, gof=1.0, bio=1e6, hbo=1e6, hbr=1e6, ref_meg=1e15, fnirs_cw_amplitude=1.0, fnirs_fd_ac_amplitude=1.0, - fnirs_fd_phase=1., fnirs_od=1.0, csd=1e3, whitened=1.), + fnirs_fd_phase=1., fnirs_od=1.0, csd=1e3, whitened=1., + gsr=1., temperature=1.), # rough guess for a good plot scalings_plot_raw=dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4, emg=1e-3, ref_meg=1e-12, misc='auto', @@ -39,7 +41,8 @@ hbr=10e-6, whitened=10., fnirs_cw_amplitude=2e-2, fnirs_fd_ac_amplitude=2e-2, fnirs_fd_phase=2e-1, fnirs_od=2e-2, csd=200e-4, - dipole=1e-7, gof=1e2), + dipole=1e-7, gof=1e2, + gsr=1., temperature=1.), scalings_cov_rank=dict(mag=1e12, grad=1e11, eeg=1e5, # ~100x scalings seeg=1e1, dbs=1e4, ecog=1e4, hbo=1e4, hbr=1e4), ylim=dict(mag=(-600., 600.), grad=(-200., 200.), eeg=(-200., 200.), @@ -55,7 +58,8 @@ fnirs_fd_phase='fNIRS (FD phase)', fnirs_od='fNIRS (OD)', hbr='Deoxyhemoglobin', gof='Goodness of fit', csd='Current source density', - stim='Stimulus', + stim='Stimulus', gsr='Galvanic skin response', + temperature='Temperature', ), mask_params=dict(marker='o', markerfacecolor='w', diff --git a/mne/io/constants.py b/mne/io/constants.py index 9f3959004b0..e37204e36a0 100644 --- a/mne/io/constants.py +++ b/mne/io/constants.py @@ -202,6 +202,8 @@ FIFF.FIFFV_DIPOLE_WAVE = 1000 # Dipole time curve (xplotter/xfit) FIFF.FIFFV_GOODNESS_FIT = 1001 # Goodness of fit (xplotter/xfit) FIFF.FIFFV_FNIRS_CH = 1100 # Functional near-infrared spectroscopy +FIFF.FIFFV_TEMPERATURE_CH = 1200 # Functional near-infrared spectroscopy +FIFF.FIFFV_GALVANIC_CH = 1300 # Galvanic skin response _ch_kind_named = {key: key for key in ( FIFF.FIFFV_BIO_CH, FIFF.FIFFV_MEG_CH, @@ -223,6 +225,8 @@ FIFF.FIFFV_DIPOLE_WAVE, FIFF.FIFFV_GOODNESS_FIT, FIFF.FIFFV_FNIRS_CH, + FIFF.FIFFV_GALVANIC_CH, + FIFF.FIFFV_TEMPERATURE_CH, )} # @@ -839,7 +843,7 @@ FIFF.FIFF_UNIT_V = 107 # volt FIFF.FIFF_UNIT_F = 108 # farad FIFF.FIFF_UNIT_OHM = 109 # ohm -FIFF.FIFF_UNIT_MHO = 110 # one per ohm +FIFF.FIFF_UNIT_S = 110 # Siemens (same as Moh, what fiff-constants calls it) FIFF.FIFF_UNIT_WB = 111 # weber FIFF.FIFF_UNIT_T = 112 # tesla FIFF.FIFF_UNIT_H = 113 # Henry @@ -861,7 +865,7 @@ FIFF.FIFF_UNIT_CD, FIFF.FIFF_UNIT_MOL_M3, FIFF.FIFF_UNIT_HZ, FIFF.FIFF_UNIT_N, FIFF.FIFF_UNIT_PA, FIFF.FIFF_UNIT_J, FIFF.FIFF_UNIT_W, FIFF.FIFF_UNIT_C, FIFF.FIFF_UNIT_V, FIFF.FIFF_UNIT_F, FIFF.FIFF_UNIT_OHM, - FIFF.FIFF_UNIT_MHO, FIFF.FIFF_UNIT_WB, FIFF.FIFF_UNIT_T, FIFF.FIFF_UNIT_H, + FIFF.FIFF_UNIT_S, FIFF.FIFF_UNIT_WB, FIFF.FIFF_UNIT_T, FIFF.FIFF_UNIT_H, FIFF.FIFF_UNIT_CEL, FIFF.FIFF_UNIT_LM, FIFF.FIFF_UNIT_LX, FIFF.FIFF_UNIT_V_M2, FIFF.FIFF_UNIT_T_M, FIFF.FIFF_UNIT_AM, FIFF.FIFF_UNIT_AM_M2, FIFF.FIFF_UNIT_AM_M3, diff --git a/mne/io/pick.py b/mne/io/pick.py index 548e3a913dc..2d33cb6d7a6 100644 --- a/mne/io/pick.py +++ b/mne/io/pick.py @@ -94,7 +94,12 @@ def get_channel_type_constants(include_defaults=False): coil_type=FIFF.FIFFV_COIL_FNIRS_HBR), csd=dict(kind=FIFF.FIFFV_EEG_CH, unit=FIFF.FIFF_UNIT_V_M2, - coil_type=FIFF.FIFFV_COIL_EEG_CSD)) + coil_type=FIFF.FIFFV_COIL_EEG_CSD), + temperature=dict(kind=FIFF.FIFFV_TEMPERATURE_CH, + unit=FIFF.FIFF_UNIT_C), + gsr=dict(kind=FIFF.FIFFV_GALVANIC_CH, + unit=FIFF.FIFF_UNIT_S), + ) if include_defaults: coil_none = dict(coil_type=FIFF.FIFFV_COIL_NONE) unit_none = dict(unit=FIFF.FIFF_UNIT_NONE) @@ -146,6 +151,8 @@ def get_channel_type_constants(include_defaults=False): FIFF.FIFFV_GOODNESS_FIT: 'gof', FIFF.FIFFV_ECOG_CH: 'ecog', FIFF.FIFFV_FNIRS_CH: 'fnirs', + FIFF.FIFFV_TEMPERATURE_CH: 'temperature', + FIFF.FIFFV_GALVANIC_CH: 'gsr', } # How to reduce our categories in channel_type (originally) _second_rules = { @@ -186,7 +193,8 @@ def channel_type(info, idx): {'grad', 'mag', 'eeg', 'csd', 'stim', 'eog', 'emg', 'ecg', 'ref_meg', 'resp', 'exci', 'ias', 'syst', 'misc', 'seeg', 'dbs', - 'bio', 'chpi', 'dipole', 'gof', 'ecog', 'hbo', 'hbr'} + 'bio', 'chpi', 'dipole', 'gof', 'ecog', 'hbo', 'hbr', + 'temperature', 'gsr'} """ # This is faster than the original _channel_type_old now in test_pick.py # because it uses (at most!) two dict lookups plus one conditional @@ -368,10 +376,11 @@ def _check_info_exclude(info, exclude): @fill_doc def pick_types(info, meg=False, eeg=False, stim=False, eog=False, ecg=False, - emg=False, ref_meg='auto', misc=False, resp=False, chpi=False, - exci=False, ias=False, syst=False, seeg=False, dipole=False, - gof=False, bio=False, ecog=False, fnirs=False, csd=False, - dbs=False, include=(), exclude='bads', selection=None): + emg=False, ref_meg='auto', *, misc=False, resp=False, + chpi=False, exci=False, ias=False, syst=False, seeg=False, + dipole=False, gof=False, bio=False, ecog=False, fnirs=False, + csd=False, dbs=False, temperature=False, gsr=False, + include=(), exclude='bads', selection=None): """Pick channels by type and names. Parameters @@ -399,7 +408,8 @@ def pick_types(info, meg=False, eeg=False, stim=False, eog=False, ecg=False, len(info['comps']) > 0 and meg is not False) for param in (eeg, stim, eog, ecg, emg, misc, resp, chpi, exci, - ias, syst, seeg, dipole, gof, bio, ecog, csd, dbs): + ias, syst, seeg, dipole, gof, bio, ecog, csd, dbs, + temperature, gsr): if not isinstance(param, bool): w = ('Parameters for all channel types (with the exception of ' '"meg", "ref_meg" and "fnirs") must be of type bool, not {}.') @@ -408,7 +418,8 @@ def pick_types(info, meg=False, eeg=False, stim=False, eog=False, ecg=False, param_dict = dict(eeg=eeg, stim=stim, eog=eog, ecg=ecg, emg=emg, misc=misc, resp=resp, chpi=chpi, exci=exci, ias=ias, syst=syst, seeg=seeg, dbs=dbs, dipole=dipole, - gof=gof, bio=bio, ecog=ecog, csd=csd) + gof=gof, bio=bio, ecog=ecog, csd=csd, + temperature=temperature, gsr=gsr) # avoid triage if possible if isinstance(meg, bool): for key in ('grad', 'mag'): @@ -911,7 +922,7 @@ def _check_excludes_includes(chs, info=None, allow_bads=False): meg=True, eeg=True, csd=True, stim=False, eog=False, ecg=False, emg=False, misc=False, resp=False, chpi=False, exci=False, ias=False, syst=False, seeg=True, dipole=False, gof=False, bio=False, ecog=True, fnirs=True, - dbs=True) + dbs=True, temperature=False, gsr=False) _PICK_TYPES_KEYS = tuple(list(_PICK_TYPES_DATA_DICT) + ['ref_meg']) _MEG_CH_TYPES_SPLIT = ('mag', 'grad', 'planar1', 'planar2') _FNIRS_CH_TYPES_SPLIT = ('hbo', 'hbr', 'fnirs_cw_amplitude', diff --git a/mne/io/tests/test_constants.py b/mne/io/tests/test_constants.py index b74c4ec3894..b334447993d 100644 --- a/mne/io/tests/test_constants.py +++ b/mne/io/tests/test_constants.py @@ -21,7 +21,7 @@ # https://github.com/mne-tools/fiff-constants/commits/master REPO = 'mne-tools' -COMMIT = 'aa49e20cff5791fbaf01d77ad4ec2e0ecb69840d' +COMMIT = '6d9ca9ce7fb44c63d429c2986a953500743dfb22' # These are oddities that we won't address: iod_dups = (355, 359) # these are in both MEGIN and MNE files diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 8f89057ad77..8181b856a44 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -2357,6 +2357,10 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): EEG-CSD channels. dbs : bool Deep brain stimulation channels. +temperature : bool + Temperature channels. +gsr : bool + Galvanic skin response channels. include : list of str List of additional channels to include. If empty do not include any. diff --git a/mne/viz/raw.py b/mne/viz/raw.py index a7ffa6e875c..8839654912f 100644 --- a/mne/viz/raw.py +++ b/mne/viz/raw.py @@ -565,7 +565,7 @@ def _setup_channel_selections(raw, kind, order): ecg=True, emg=True, ref_meg=False, misc=True, resp=True, chpi=True, exci=True, ias=True, syst=True, seeg=False, bio=True, ecog=False, fnirs=False, dbs=False, - exclude=()) + temperature=True, gsr=True, exclude=()) if len(misc) and np.in1d(misc, order).any(): selections_dict['Misc'] = misc return selections_dict From 6cebb1332a5e33c9a4e4d6999625a4cc585803e1 Mon Sep 17 00:00:00 2001 From: Alexandre Gramfort Date: Sat, 27 Aug 2022 09:40:46 +0100 Subject: [PATCH 2/7] ENH : add units parameter to read_raw_edf in case units is missing from the file (#11099) * ENH : add units parameter to read_raw_edf in case units is missing from the file * cleanup * update what's new + flake8 * adding units to bdf too * fix docstring --- doc/changes/latest.inc | 1 + mne/io/edf/edf.py | 31 +++++++++++++++++++++++++------ mne/utils/docs.py | 8 ++++++++ 3 files changed, 34 insertions(+), 6 deletions(-) diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 884aadf9cfa..b7bb6013021 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -35,6 +35,7 @@ Enhancements - Add HTML representation for `~mne.Evoked` in Jupyter Notebooks (:gh:`11075` by `Valerii Chirkov`_ and `Andrew Quinn`_) - Add support for ``temperature`` and ``gsr`` (galvanic skin response, i.e., electrodermal activity) channel types (:gh:`11090` by `Eric Larson`_) - Allow :func:`mne.beamformer.make_dics` to take ``pick_ori='vector'`` to compute vector source estimates (:gh:`19080` by `Alex Rockhill`_) +- Add ``units`` parameter to :func:`mne.io.read_raw_edf` in case units are missing from the file (:gh:`11099` by `Alex Gramfort`_) - Add ``on_missing`` functionality to all of our classes that have a ``drop_channels`` method, to control what happens when channel names are not in the object (:gh:`11077` by `Andrew Quinn`_) Bugs diff --git a/mne/io/edf/edf.py b/mne/io/edf/edf.py index f7f820f8902..f71b782cb8c 100644 --- a/mne/io/edf/edf.py +++ b/mne/io/edf/edf.py @@ -19,7 +19,7 @@ from ...utils import verbose, logger, warn from ..utils import _blk_read_lims, _mult_cal_one -from ..base import BaseRaw +from ..base import BaseRaw, _get_scaling from ..meas_info import _empty_info, _unique_channel_names from ..constants import FIFF from ...filter import resample @@ -83,6 +83,7 @@ class RawEDF(BaseRaw): .. versionadded:: 1.1 %(preload)s + %(units_edf_bdf_io)s %(verbose)s See Also @@ -132,7 +133,7 @@ class RawEDF(BaseRaw): @verbose def __init__(self, input_fname, eog=None, misc=None, stim_channel='auto', exclude=(), infer_types=False, preload=False, include=None, - verbose=None): + units=None, *, verbose=None): logger.info('Extracting EDF parameters from {}...'.format(input_fname)) input_fname = os.path.abspath(input_fname) info, edf_info, orig_units = _get_info(input_fname, stim_channel, eog, @@ -140,6 +141,22 @@ def __init__(self, input_fname, eog=None, misc=None, stim_channel='auto', preload, include) logger.info('Creating raw.info structure...') + if units is not None and isinstance(units, str): + units = {ch_name: units for ch_name in info['ch_names']} + elif units is None: + units = dict() + + for k, (this_ch, this_unit) in enumerate(orig_units.items()): + if this_unit != "" and this_unit in units: + raise ValueError(f'Unit for channel {this_ch} is present in ' + 'the file. Cannot overwrite it with the ' + 'units argument.') + if this_unit == "" and this_ch in units: + orig_units[this_ch] = units[this_ch] + ch_type = edf_info["ch_types"][k] + scaling = _get_scaling(ch_type.lower(), orig_units[this_ch]) + edf_info["units"][k] /= scaling + # Raw attributes last_samps = [edf_info['nsamples'] - 1] super().__init__(info, preload, filenames=[input_fname], @@ -1282,7 +1299,7 @@ def _find_tal_idx(ch_names): @fill_doc def read_raw_edf(input_fname, eog=None, misc=None, stim_channel='auto', exclude=(), infer_types=False, include=None, preload=False, - verbose=None): + units=None, *, verbose=None): """Reader function for EDF or EDF+ files. Parameters @@ -1322,6 +1339,7 @@ def read_raw_edf(input_fname, eog=None, misc=None, stim_channel='auto', .. versionadded:: 1.1 %(preload)s + %(units_edf_bdf_io)s %(verbose)s Returns @@ -1384,13 +1402,13 @@ def read_raw_edf(input_fname, eog=None, misc=None, stim_channel='auto', return RawEDF(input_fname=input_fname, eog=eog, misc=misc, stim_channel=stim_channel, exclude=exclude, infer_types=infer_types, preload=preload, include=include, - verbose=verbose) + units=units, verbose=verbose) @fill_doc def read_raw_bdf(input_fname, eog=None, misc=None, stim_channel='auto', exclude=(), infer_types=False, include=None, preload=False, - verbose=None): + units=None, *, verbose=None): """Reader function for BDF files. Parameters @@ -1430,6 +1448,7 @@ def read_raw_bdf(input_fname, eog=None, misc=None, stim_channel='auto', .. versionadded:: 1.1 %(preload)s + %(units_edf_bdf_io)s %(verbose)s Returns @@ -1485,7 +1504,7 @@ def read_raw_bdf(input_fname, eog=None, misc=None, stim_channel='auto', return RawEDF(input_fname=input_fname, eog=eog, misc=misc, stim_channel=stim_channel, exclude=exclude, infer_types=infer_types, preload=preload, include=include, - verbose=verbose) + units=units, verbose=verbose) @fill_doc diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 8181b856a44..bfb17cf27e1 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -3508,6 +3508,14 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): channel-type-specific default unit. """ +docdict['units_edf_bdf_io'] = """ +units : dict | str + The units of the channels as stored in the file. This argument + is useful only if the units are missing from the original file. + If a dict, it must map a channel name to its unit, and if str + it is assumed that all channels have the same units. +""" + docdict['units_topomap'] = """ units : dict | str | None The unit of the channel type used for colorbar label. If From 93485e025d576470c21cc936becb3992c01a9c5e Mon Sep 17 00:00:00 2001 From: Daniel McCloy Date: Sat, 27 Aug 2022 09:58:16 +0100 Subject: [PATCH 3/7] add spectrum class (#10184) * STY: alphabetize imports * wip: first sketch of spectrum class [ci skip] implement __repr__; add placeholder _repr_html_ add draft of _repr_html_ default to multitaper for Evokeds make raw.plot_psd() use the new code path unify viz.plot_raw_psd code path too support unaggregated multitaper add picks param to spectrum.plot() fix(ish) the units() method allow average=False as synonym for None handle unaggregated estimates in combo with epochs fix CI plotting implement get_data method [ci skip] * refactor aggregation * fix instance type checking * improve TODO notes * WIP use new class in plot_psd_topo [ci skip] * docdict additions * implement Spectrum.to_data_frame * test Spectrum.to_data_frame * adapt to_data_frame for unaggregated spectra [ci skip] * test unaggregated welch to df * test unagg multitaper to df * make tests more similar * make DRY * fix epoch test * simplify test * fix flake * use requires_pandas * fix bad rebase * fix API for epochs * use new API in example * tiny docstring improvement * fix unused import * fix docdict key order * do it the smarter/safer way * update tests to avoid deprecated calls * better deprectation message * convert more deprecated func calls to new method * unused imports * fix tests * more unused imports * fix circular imports also: - apply isort to a couple files - revert distracting isort on otherwise barely touched file * make CIs pass * get I/O working * don't store verbose attr * test IO * fix D202 * fix flake * better docstring for save method and read func * fix compute_psd docstring * return value descr * decorate test (h5py) * fixup after rebase * reorder methods * add __getitem__ functionality * __getitem__ tests * add __eq__, test for .copy(), refactor IO test * refactor to separate epochs class * EpochsSpectrum IO * fix type checking, better variable naming * test evoked IO too [ci skip] * fix type checking some more * adjust .units() for complex multitaper data * fix flake * docstring refactor * working plot_topomap implementation * fix docstring tests * fix pydocstyle * add EpochsSpectrum to the public API * make test more DRY * tweak deprecation message * docdict/docstrings fixes * add plot_psd_topomap to mixin * pytest limitation workaround * work toward unifying plot_topomap API * don't silently overwrite units * TODO comments [ci skip] * WIP tutorial changes * TODO: plot_psd_topo * WIP plot_topo & docdict stuff * plot_topomap docstring dedup * dedup legacy n_fft default * fix varname * more WIP plot_psd_topo * finish migrating plot_psd_topo to mixin * use new code path for plot_topomap * flake * docstring tests * unused imports * whitespace * flake * flake again * add plot_topo as spectrum method * don't do too much * fix test * flake * fix test * fix tutorial * WIP spectrum class tutorial * codespell * update tutorial and repr_html template * better repr, better shape checking * tweaks from self-review * update changelog * flake * fix * undo isort / other unrelated changes * use new API in tutorial * remove redundant plt_show * standardize docstring order * explain setup.cfg entry * fix html repr of units * simplify __eq__ by improving object_diff * remove deepcopy override * fix flake8 config * remove superfluous BibTeX fields [ci skip] Co-authored-by: Marijn van Vliet * misc fixes [ci skip] Co-authored-by: Marijn van Vliet * Update mne/viz/utils.py [ci skip] Co-authored-by: Eric Larson * update old tutorials more thoroughly * file encoding / test comments Co-authored-by: Marijn van Vliet * use __setstate__ and __getstate__ * flake * fix reject_by_annot appearing where it shouldn't * fix docstrings Co-authored-by: Marijn van Vliet Co-authored-by: Eric Larson --- doc/_static/style.css | 3 + doc/changes/latest.inc | 2 + doc/conf.py | 2 + doc/references.bib | 11 + doc/time_frequency.rst | 3 + mne/channels/channels.py | 11 +- mne/cov.py | 2 + mne/decoding/transformer.py | 2 +- mne/epochs.py | 136 ++- mne/evoked.py | 94 +- mne/filter.py | 3 +- mne/html_templates/repr/spectrum.html.jinja | 50 + mne/io/base.py | 79 +- mne/preprocessing/ica.py | 5 +- mne/simulation/tests/test_raw.py | 7 +- mne/time_frequency/__init__.py | 1 + mne/time_frequency/multitaper.py | 4 +- mne/time_frequency/psd.py | 103 +- mne/time_frequency/spectrum.py | 1146 +++++++++++++++++++ mne/time_frequency/tests/test_csd.py | 6 +- mne/time_frequency/tests/test_multitaper.py | 57 +- mne/time_frequency/tests/test_psd.py | 359 ++---- mne/time_frequency/tests/test_spectrum.py | 184 +++ mne/utils/check.py | 3 +- mne/utils/dataframe.py | 2 +- mne/utils/docs.py | 192 +++- mne/utils/mixin.py | 2 +- mne/utils/numerics.py | 12 +- mne/viz/epochs.py | 43 +- mne/viz/ica.py | 9 +- mne/viz/raw.py | 123 +- mne/viz/tests/test_epochs.py | 7 +- mne/viz/tests/test_raw.py | 7 +- mne/viz/topomap.py | 114 +- mne/viz/utils.py | 25 +- setup.cfg | 5 +- tutorials/clinical/60_sleep.py | 13 +- tutorials/epochs/20_visualize_epochs.py | 55 +- tutorials/raw/40_visualize_raw.py | 56 +- tutorials/time-freq/10_spectrum_class.py | 172 +++ 40 files changed, 2397 insertions(+), 713 deletions(-) create mode 100644 mne/html_templates/repr/spectrum.html.jinja create mode 100644 mne/time_frequency/spectrum.py create mode 100644 mne/time_frequency/tests/test_spectrum.py create mode 100644 tutorials/time-freq/10_spectrum_class.py diff --git a/doc/_static/style.css b/doc/_static/style.css index 48597c2481a..7d9768b388c 100644 --- a/doc/_static/style.css +++ b/doc/_static/style.css @@ -327,3 +327,6 @@ ul.icon-bullets { img.hidden { visibility: hidden; } +td.justify { + text-align-last: justify; +} diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index b7bb6013021..b969ac037ec 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -55,3 +55,5 @@ API changes ~~~~~~~~~~~ - The ``bands`` parameter of :meth:`mne.Epochs.plot_psd_topomap` now accepts :class:`dict` input; legacy :class:`tuple` input is supported, but discouraged for new code (:gh:`11050` by `Daniel McCloy`_) - The ``show_toolbar`` argument to :class:`mne.viz.Brain` is being removed by deprecation (:gh:`11049` by `Eric Larson`_) +- New classes :class:`~mne.time_frequency.Spectrum` and :class:`~mne.time_frequency.EpochsSpectrum`, created via new methods :meth:`Raw.compute_psd()`, :meth:`Epochs.compute_psd()`, and :meth:`Evoked.compute_psd()` (:gh:`10184` by `Daniel McCloy`_) +- The PSD functions that operate on Raw/Epochs/Evoked instances (``mne.time_frequency.psd_welch`` and ``mne.time_frequency.psd_multitaper``) are deprecated; for equivalent functionality create :class:`~mne.time_frequency.Spectrum` or :class:`~mne.time_frequency.EpochsSpectrum` objects instead and then run ``spectrum.get_data(return_freqs=True)`` (:gh:`10184` by `Daniel McCloy`_) diff --git a/doc/conf.py b/doc/conf.py index 7e4b4ac537d..67f9d88ff96 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -236,6 +236,8 @@ 'Transform': 'mne.transforms.Transform', 'Coregistration': 'mne.coreg.Coregistration', 'Figure3D': 'mne.viz.Figure3D', + 'Spectrum': 'mne.time_frequency.Spectrum', + 'EpochsSpectrum': 'mne.time_frequency.EpochsSpectrum', # dipy 'dipy.align.AffineMap': 'dipy.align.imaffine.AffineMap', 'dipy.align.DiffeomorphicMap': 'dipy.align.imwarp.DiffeomorphicMap', diff --git a/doc/references.bib b/doc/references.bib index 809318b1627..ef4f2cbc9d9 100644 --- a/doc/references.bib +++ b/doc/references.bib @@ -2285,6 +2285,17 @@ @article{LuckGaspelin2017 doi = {10.1111/psyp.12639}, } +@article{Welch1967, + title = {The Use of Fast {{Fourier}} Transform for the Estimation of Power Spectra: {{A}} Method Based on Time Averaging over Short, Modified Periodograms}, + author = {Welch, Peter D.}, + year = {1967}, + journal = {IEEE Transactions on Audio and Electroacoustics}, + volume = {15}, + number = {2}, + pages = {70--73}, + doi = {10.1109/TAU.1967.1161901}, +} + @article{MaksymenkoEtAl2017, title = {Strategies for statistical thresholding of source localization maps in magnetoencephalography and estimating source extent}, volume = {290}, diff --git a/doc/time_frequency.rst b/doc/time_frequency.rst index 02ab5d28c0f..ebab1af3d26 100644 --- a/doc/time_frequency.rst +++ b/doc/time_frequency.rst @@ -16,6 +16,8 @@ Time-Frequency AverageTFR EpochsTFR CrossSpectralDensity + Spectrum + EpochsSpectrum Functions that operate on mne-python objects: @@ -36,6 +38,7 @@ Functions that operate on mne-python objects: tfr_stockwell read_tfrs write_tfrs + read_spectrum Functions that operate on ``np.ndarray`` objects: diff --git a/mne/channels/channels.py b/mne/channels/channels.py index f2f22fadc7e..988ded901ce 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -586,7 +586,7 @@ def set_meas_date(self, meas_date): class UpdateChannelsMixin(object): - """Mixin class for Raw, Evoked, Epochs, AverageTFR.""" + """Mixin class for Raw, Evoked, Epochs, Spectrum, AverageTFR.""" @verbose def pick_types(self, meg=False, eeg=False, stim=False, eog=False, @@ -791,6 +791,7 @@ def _pick_drop_channels(self, idx, *, verbose=None): # avoid circular imports from ..io import BaseRaw from ..time_frequency import AverageTFR, EpochsTFR + from ..time_frequency.spectrum import BaseSpectrum msg = 'adding, dropping, or reordering channels' if isinstance(self, BaseRaw): @@ -815,8 +816,12 @@ def _pick_drop_channels(self, idx, *, verbose=None): if mat is not None: setattr(self, key, mat[idx][:, idx]) - # All others (Evoked, Epochs, Raw) have chs axis=-2 - axis = -3 if isinstance(self, (AverageTFR, EpochsTFR)) else -2 + if isinstance(self, BaseSpectrum): + axis = self._dims.index('channel') + elif isinstance(self, (AverageTFR, EpochsTFR)): + axis = -3 + else: # All others (Evoked, Epochs, Raw) have chs axis=-2 + axis = -2 if hasattr(self, '_data'): # skip non-preloaded Raw self._data = self._data.take(idx, axis=axis) else: diff --git a/mne/cov.py b/mne/cov.py index 988011c53e5..abe81b56446 100644 --- a/mne/cov.py +++ b/mne/cov.py @@ -267,6 +267,8 @@ def plot_topomap(self, info, ch_type=None, vmin=None, ---------- %(info_not_none)s %(ch_type_topomap)s + + .. versionadded:: 0.21 %(vmin_vmax_topomap)s %(cmap_topomap)s %(sensors_topomap)s diff --git a/mne/decoding/transformer.py b/mne/decoding/transformer.py index 0c86384657f..a6cb26e1dda 100644 --- a/mne/decoding/transformer.py +++ b/mne/decoding/transformer.py @@ -12,7 +12,7 @@ from .. import pick_types from ..filter import filter_data -from ..time_frequency.psd import psd_array_multitaper +from ..time_frequency import psd_array_multitaper from ..utils import fill_doc, _check_option, _validate_type, verbose from ..io.pick import (pick_info, _pick_data_channels, _picks_by_type, _picks_to_idx) diff --git a/mne/epochs.py b/mne/epochs.py index b02b90eeea1..8cd1fc5e21f 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -37,7 +37,7 @@ pick_channels, pick_info, _pick_data_channels, _DATA_CH_TYPES_SPLIT, _picks_to_idx) from .io.proj import setup_proj, ProjMixin -from .io.base import BaseRaw, _get_ch_factors +from .io.base import BaseRaw, TimeMixin, _get_ch_factors from .bem import _check_origin from .evoked import EvokedArray from .baseline import rescale, _log_rescale, _check_baseline @@ -49,13 +49,14 @@ from .event import (_read_events_fif, make_fixed_length_events, match_event_names) from .fixes import rng_uniform -from .viz import (plot_epochs, plot_epochs_psd, plot_epochs_psd_topomap, - plot_epochs_image, plot_topo_image_epochs, plot_drop_log) +from .time_frequency.spectrum import EpochsSpectrum, SpectrumMixin +from .viz import (plot_epochs, plot_epochs_image, + plot_topo_image_epochs, plot_drop_log) from .utils import (_check_fname, check_fname, logger, verbose, check_random_state, warn, _pl, sizeof_fmt, SizeMixin, copy_function_doc_to_method_doc, _check_pandas_installed, - _check_preload, GetEpochsMixin, TimeMixin, + _check_preload, GetEpochsMixin, _prepare_read_metadata, _prepare_write_metadata, _check_event_id, _gen_events, _check_option, _check_combine, _build_data_frame, @@ -340,7 +341,8 @@ def _handle_event_repeated(events, event_id, event_repeated, selection, @fill_doc class BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin, SetChannelsMixin, InterpolationMixin, FilterMixin, - TimeMixin, SizeMixin, GetEpochsMixin, EpochAnnotationsMixin): + TimeMixin, SizeMixin, GetEpochsMixin, EpochAnnotationsMixin, + SpectrumMixin): """Abstract base class for `~mne.Epochs`-type classes. .. warning:: This class provides basic functionality and should never be @@ -1122,41 +1124,6 @@ def plot(self, picks=None, scalings=None, n_epochs=20, n_channels=20, use_opengl=use_opengl, theme=theme, overview_mode=overview_mode) - @copy_function_doc_to_method_doc(plot_epochs_psd) - def plot_psd(self, fmin=0, fmax=np.inf, tmin=None, tmax=None, - proj=False, bandwidth=None, adaptive=False, low_bias=True, - normalization='length', picks=None, ax=None, color='black', - xscale='linear', area_mode='std', area_alpha=0.33, - dB=True, estimate='auto', show=True, n_jobs=None, - average=False, line_alpha=None, spatial_colors=True, - sphere=None, exclude='bads', verbose=None): - return plot_epochs_psd(self, fmin=fmin, fmax=fmax, tmin=tmin, - tmax=tmax, proj=proj, bandwidth=bandwidth, - adaptive=adaptive, low_bias=low_bias, - normalization=normalization, picks=picks, ax=ax, - color=color, xscale=xscale, area_mode=area_mode, - area_alpha=area_alpha, dB=dB, estimate=estimate, - show=show, n_jobs=n_jobs, average=average, - line_alpha=line_alpha, - spatial_colors=spatial_colors, sphere=sphere, - exclude=exclude, verbose=verbose) - - @copy_function_doc_to_method_doc(plot_epochs_psd_topomap) - def plot_psd_topomap(self, bands=None, tmin=None, - tmax=None, proj=False, bandwidth=None, adaptive=False, - low_bias=True, normalization='length', ch_type=None, - cmap=None, agg_fun=None, dB=True, - n_jobs=None, normalize=False, cbar_fmt='auto', - outlines='head', axes=None, show=True, - sphere=None, vlim=(None, None), verbose=None): - return plot_epochs_psd_topomap( - self, bands=bands, tmin=tmin, tmax=tmax, - proj=proj, bandwidth=bandwidth, adaptive=adaptive, - low_bias=low_bias, normalization=normalization, ch_type=ch_type, - cmap=cmap, agg_fun=agg_fun, dB=dB, n_jobs=n_jobs, - normalize=normalize, cbar_fmt=cbar_fmt, outlines=outlines, - axes=axes, show=show, sphere=sphere, vlim=vlim, verbose=verbose) - @copy_function_doc_to_method_doc(plot_topo_image_epochs) def plot_topo_image(self, layout=None, sigma=0., vmin=None, vmax=None, colorbar=None, order=None, cmap='RdBu_r', @@ -2021,6 +1988,95 @@ def equalize_event_counts(self, event_ids=None, method='mintime'): # actually remove the indices return self, indices + @verbose + def compute_psd(self, method='multitaper', fmin=0, fmax=np.inf, tmin=None, + tmax=None, picks=None, proj=False, *, n_jobs=1, + verbose=None, **method_kw): + """Perform spectral analysis on sensor data. + + Parameters + ---------- + %(method_psd)s + Default is ``'multitaper'``. + %(fmin_fmax_psd)s + %(tmin_tmax_psd)s + %(picks_good_data_noref)s + %(proj_psd)s + %(n_jobs)s + %(verbose)s + %(method_kw_psd)s + + Returns + ------- + spectrum : instance of EpochsSpectrum + The spectral representation of each epoch. + + References + ---------- + .. footbibliography:: + """ + return EpochsSpectrum( + self, method=method, fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, + picks=picks, proj=proj, n_jobs=n_jobs, verbose=verbose, + **method_kw) + + @verbose + def plot_psd(self, fmin=0, fmax=np.inf, tmin=None, tmax=None, picks=None, + proj=False, *, method='auto', average=False, dB=True, + estimate='auto', xscale='linear', area_mode='std', + area_alpha=0.33, color='black', line_alpha=None, + spatial_colors=True, sphere=None, exclude='bads', ax=None, + show=True, n_jobs=1, verbose=None, **method_kw): + """%(plot_psd_doc)s. + + Parameters + ---------- + %(fmin_fmax_psd)s + %(tmin_tmax_psd)s + %(picks_good_data_noref)s + %(proj_psd)s + %(method_plot_psd_auto)s + %(average_plot_psd)s + %(dB_plot_psd)s + %(estimate_plot_psd)s + %(xscale_plot_psd)s + %(area_mode_plot_psd)s + %(area_alpha_plot_psd)s + %(color_plot_psd)s + %(line_alpha_plot_psd)s + %(spatial_colors_psd)s + %(sphere_topomap_auto)s + + .. versionadded:: 0.22.0 + exclude : list of str | 'bads' + Channels names to exclude from being shown. If 'bads', the bad + channels are excluded. Pass an empty list to plot all channels + (including channels marked "bad", if any). + + .. versionadded:: 0.24.0 + %(ax_plot_psd)s + %(show)s + %(n_jobs)s + %(verbose)s + %(method_kw_psd)s + + Returns + ------- + fig : instance of Figure + Figure with frequency spectra of the data channels. + + Notes + ----- + %(notes_plot_psd_meth)s + """ + return super().plot_psd( + fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, picks=picks, proj=proj, + reject_by_annotation=False, method=method, average=average, dB=dB, + estimate=estimate, xscale=xscale, area_mode=area_mode, + area_alpha=area_alpha, color=color, line_alpha=line_alpha, + spatial_colors=spatial_colors, sphere=sphere, exclude=exclude, + ax=ax, show=show, n_jobs=n_jobs, verbose=verbose, **method_kw) + @verbose def to_data_frame(self, picks=None, index=None, scalings=None, copy=True, long_format=False, diff --git a/mne/evoked.py b/mne/evoked.py index 18ccc250b24..84ddc245834 100644 --- a/mne/evoked.py +++ b/mne/evoked.py @@ -9,6 +9,7 @@ # License: BSD-3-Clause from copy import deepcopy + import numpy as np from .baseline import rescale, _log_rescale, _check_baseline @@ -43,6 +44,7 @@ write_id, write_float, write_complex_float_matrix) from .io.base import _check_maxshield, _get_ch_factors from .parallel import parallel_func +from .time_frequency.spectrum import Spectrum, SpectrumMixin _aspect_dict = { 'average': FIFF.FIFFV_ASPECT_AVERAGE, @@ -62,7 +64,8 @@ @fill_doc class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin, SetChannelsMixin, - InterpolationMixin, FilterMixin, TimeMixin, SizeMixin): + InterpolationMixin, FilterMixin, TimeMixin, SizeMixin, + SpectrumMixin): """Evoked data. Parameters @@ -728,6 +731,95 @@ def get_peak(self, ch_type=None, tmin=None, tmax=None, return out + @verbose + def compute_psd(self, method='multitaper', fmin=0, fmax=np.inf, tmin=None, + tmax=None, picks=None, proj=False, *, n_jobs=1, + verbose=None, **method_kw): + """Perform spectral analysis on sensor data. + + Parameters + ---------- + %(method_psd)s + Default is ``'multitaper'``. + %(fmin_fmax_psd)s + %(tmin_tmax_psd)s + %(picks_good_data_noref)s + %(proj_psd)s + %(n_jobs)s + %(verbose)s + %(method_kw_psd)s + + Returns + ------- + spectrum : instance of Spectrum + The spectral representation of the data. + + References + ---------- + .. footbibliography:: + """ + return Spectrum( + self, method=method, fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, + picks=picks, proj=proj, reject_by_annotation=False, n_jobs=n_jobs, + verbose=verbose, **method_kw) + + @verbose + def plot_psd(self, fmin=0, fmax=np.inf, tmin=None, tmax=None, picks=None, + proj=False, *, method='auto', average=False, dB=True, + estimate='auto', xscale='linear', area_mode='std', + area_alpha=0.33, color='black', line_alpha=None, + spatial_colors=True, sphere=None, exclude='bads', ax=None, + show=True, n_jobs=1, verbose=None, **method_kw): + """%(plot_psd_doc)s. + + Parameters + ---------- + %(fmin_fmax_psd)s + %(tmin_tmax_psd)s + %(picks_good_data_noref)s + %(proj_psd)s + %(method_plot_psd_auto)s + %(average_plot_psd)s + %(dB_plot_psd)s + %(estimate_plot_psd)s + %(xscale_plot_psd)s + %(area_mode_plot_psd)s + %(area_alpha_plot_psd)s + %(color_plot_psd)s + %(line_alpha_plot_psd)s + %(spatial_colors_psd)s + %(sphere_topomap_auto)s + + .. versionadded:: 0.22.0 + exclude : list of str | 'bads' + Channels names to exclude from being shown. If 'bads', the bad + channels are excluded. Pass an empty list to plot all channels + (including channels marked "bad", if any). + + .. versionadded:: 0.24.0 + %(ax_plot_psd)s + %(show)s + %(n_jobs)s + %(verbose)s + %(method_kw_psd)s + + Returns + ------- + fig : instance of Figure + Figure with frequency spectra of the data channels. + + Notes + ----- + %(notes_plot_psd_meth)s + """ + return super().plot_psd( + fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, picks=picks, proj=proj, + reject_by_annotation=False, method=method, average=average, dB=dB, + estimate=estimate, xscale=xscale, area_mode=area_mode, + area_alpha=area_alpha, color=color, line_alpha=line_alpha, + spatial_colors=spatial_colors, sphere=sphere, exclude=exclude, + ax=ax, show=show, n_jobs=n_jobs, verbose=verbose, **method_kw) + @verbose def to_data_frame(self, picks=None, index=None, scalings=None, copy=True, long_format=False, diff --git a/mne/filter.py b/mne/filter.py index 16b387e4886..1a0c742786b 100644 --- a/mne/filter.py +++ b/mne/filter.py @@ -11,7 +11,6 @@ from .cuda import (_setup_cuda_fft_multiply_repeated, _fft_multiply_repeated, _setup_cuda_fft_resample, _fft_resample, _smart_pad) from .parallel import parallel_func -from .time_frequency.multitaper import _mt_spectra, _compute_mt_params from .utils import (logger, verbose, sum_squared, warn, _pl, _check_preload, _validate_type, _check_option, _ensure_int) from ._ola import _COLA @@ -1203,6 +1202,7 @@ def _get_window_thresh(n_times, sfreq, mt_bandwidth, p_value): # but if we have a new enough scipy, # it's only ~0.175 sec for 8 tapers even with 100000 samples from scipy import stats + from .time_frequency.multitaper import _compute_mt_params dpss_n_times_max = 100000 # figure out what tapers to use @@ -1297,6 +1297,7 @@ def _mt_spectrum_remove(x, sfreq, line_freqs, notch_widths, Based on Chronux. If line_freqs is specified, all freqs within notch_width of each line_freq is set to zero. """ + from .time_frequency.multitaper import _mt_spectra assert x.ndim == 1 if x.shape[-1] != window_fun.shape[-1]: window_fun, threshold = get_thresh(x.shape[-1]) diff --git a/mne/html_templates/repr/spectrum.html.jinja b/mne/html_templates/repr/spectrum.html.jinja new file mode 100644 index 00000000000..ee35f8e2ec1 --- /dev/null +++ b/mne/html_templates/repr/spectrum.html.jinja @@ -0,0 +1,50 @@ + + + + + + {%- for unit in units %} + + {%- if loop.index == 1 %} + + {%- endif %} + + + {%- endfor %} + + + + + {%- if inst_type == "Epochs" %} + + + + + {% endif -%} + + + + + + + + + {% if "taper" in spectrum._dims %} + + + + + {% endif %} + + + + + + + + + + + + +
Data type{{ spectrum._data_type }}
Units{{ unit }}
Data source{{ inst_type }}
Number of epochs{{ spectrum.shape[0] }}
Dims{{ spectrum._dims | join(", ") }}
Estimation method{{ spectrum.method }}
Number of tapers{{ spectrum._mt_weights.size }}
Number of channels{{ spectrum.ch_names|length }}
Number of frequency bins{{ spectrum.freqs|length }}
Frequency range{{ '%.2f'|format(spectrum.freqs[0]) }} – {{ '%.2f'|format(spectrum.freqs[-1]) }} Hz
diff --git a/mne/io/base.py b/mne/io/base.py index a216aa8a9a2..4e9b3dc0cba 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -34,8 +34,9 @@ write_complex64, write_complex128, write_int, write_id, write_string, _get_split_size, _NEXT_FILE_BUFFER) -from ..annotations import (_annotations_starts_stops, _write_annotations, - _handle_meas_date) +from ..annotations import (Annotations, _annotations_starts_stops, + _combine_annotations, _handle_meas_date, + _sync_onset, _write_annotations) from ..filter import (FilterMixin, notch_filter, resample, _resamp_ratio_len, _resample_stim_channels, _check_fun) from ..parallel import parallel_func @@ -48,14 +49,15 @@ _build_data_frame, _convert_times, _scale_dataframe_data, _check_time_format, _arange_div, TimeMixin) from ..defaults import _handle_default -from ..viz import plot_raw, plot_raw_psd, plot_raw_psd_topo, _RAW_CLIP_DEF +from ..viz import plot_raw, _RAW_CLIP_DEF from ..event import find_events, concatenate_events -from ..annotations import Annotations, _combine_annotations, _sync_onset +from ..time_frequency.spectrum import Spectrum, SpectrumMixin @fill_doc class BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin, SetChannelsMixin, - InterpolationMixin, TimeMixin, SizeMixin, FilterMixin): + InterpolationMixin, TimeMixin, SizeMixin, FilterMixin, + SpectrumMixin): """Base class for Raw data. Parameters @@ -1535,39 +1537,6 @@ def plot(self, events=None, duration=10.0, start=0.0, n_channels=20, theme=theme, overview_mode=overview_mode, verbose=verbose) - @verbose - @copy_function_doc_to_method_doc(plot_raw_psd) - def plot_psd(self, fmin=0, fmax=np.inf, tmin=None, tmax=None, proj=False, - n_fft=None, n_overlap=0, reject_by_annotation=True, - picks=None, ax=None, color='black', xscale='linear', - area_mode='std', area_alpha=0.33, dB=True, estimate='auto', - show=True, n_jobs=None, average=False, line_alpha=None, - spatial_colors=True, sphere=None, window='hamming', - exclude='bads', verbose=None): - return plot_raw_psd(self, fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, - proj=proj, n_fft=n_fft, n_overlap=n_overlap, - reject_by_annotation=reject_by_annotation, - picks=picks, ax=ax, color=color, xscale=xscale, - area_mode=area_mode, area_alpha=area_alpha, - dB=dB, estimate=estimate, show=show, n_jobs=n_jobs, - average=average, line_alpha=line_alpha, - spatial_colors=spatial_colors, sphere=sphere, - window=window, exclude=exclude, verbose=verbose) - - @copy_function_doc_to_method_doc(plot_raw_psd_topo) - def plot_psd_topo(self, tmin=0., tmax=None, fmin=0, fmax=100, proj=False, - n_fft=2048, n_overlap=0, layout=None, color='w', - fig_facecolor='k', axis_facecolor='k', dB=True, - show=True, block=False, n_jobs=None, axes=None, - verbose=None): - return plot_raw_psd_topo(self, tmin=tmin, tmax=tmax, fmin=fmin, - fmax=fmax, proj=proj, n_fft=n_fft, - n_overlap=n_overlap, layout=layout, - color=color, fig_facecolor=fig_facecolor, - axis_facecolor=axis_facecolor, dB=dB, - show=show, block=block, n_jobs=n_jobs, - axes=axes, verbose=verbose) - @property def ch_names(self): """Channel names.""" @@ -1833,6 +1802,40 @@ def _get_buffer_size(self, buffer_size_sec=None): buffer_size_sec = float(buffer_size_sec) return int(np.ceil(buffer_size_sec * self.info['sfreq'])) + @verbose + def compute_psd(self, method='welch', fmin=0, fmax=np.inf, tmin=None, + tmax=None, picks=None, proj=False, + reject_by_annotation=True, *, n_jobs=1, verbose=None, + **method_kw): + """Perform spectral analysis on sensor data. + + Parameters + ---------- + %(method_psd)s + Default is ``'welch'``. + %(fmin_fmax_psd)s + %(tmin_tmax_psd)s + %(picks_good_data_noref)s + %(proj_psd)s + %(reject_by_annotation_psd)s + %(n_jobs)s + %(verbose)s + %(method_kw_psd)s + + Returns + ------- + spectrum : instance of Spectrum + The spectral representation of the data. + + References + ---------- + .. footbibliography:: + """ + return Spectrum( + self, method=method, fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, + picks=picks, proj=proj, reject_by_annotation=reject_by_annotation, + n_jobs=n_jobs, verbose=verbose, **method_kw) + @verbose def to_data_frame(self, picks=None, index=None, scalings=None, copy=True, start=None, stop=None, diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 71954d5d307..28797ae8239 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -49,7 +49,6 @@ from ..channels.channels import _contains_ch_type from ..channels.layout import _find_topomap_coords -from ..time_frequency import psd_welch, psd_multitaper from ..io.write import start_and_end_file, write_id from ..utils import (logger, check_fname, _check_fname, verbose, _reject_data_segments, check_random_state, _validate_type, @@ -1648,8 +1647,8 @@ def find_bads_muscle(self, inst, threshold=0.5, start=None, components = self.get_components() # compute metric #1: slope of the log-log psd - psd_func = psd_welch if isinstance(inst, BaseRaw) else psd_multitaper - psds, freqs = psd_func(sources, fmin=l_freq, fmax=h_freq, picks='misc') + spectrum = sources.compute_psd(fmin=l_freq, fmax=h_freq, picks='misc') + psds, freqs = spectrum.get_data(return_freqs=True) slopes = np.polyfit(np.log10(freqs), np.log10(psds).T, 1)[0] # compute metric #2: distance from the vertex of focus diff --git a/mne/simulation/tests/test_raw.py b/mne/simulation/tests/test_raw.py index c584291a9a5..cbddd5cb70d 100644 --- a/mne/simulation/tests/test_raw.py +++ b/mne/simulation/tests/test_raw.py @@ -31,7 +31,6 @@ from mne.surface import _get_ico_surface from mne.io import read_raw_fif, RawArray from mne.io.constants import FIFF -from mne.time_frequency import psd_welch from mne.utils import catch_logging, check_version base_path = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data') @@ -488,8 +487,10 @@ def test_simulate_raw_chpi(): picks_eeg = pick_types(raw.info, meg=False, eeg=True) for picks in [picks_meg[:3], picks_eeg[:3]]: - psd_sim, freqs_sim = psd_welch(raw_sim, picks=picks) - psd_chpi, freqs_chpi = psd_welch(raw_chpi, picks=picks) + psd_sim, freqs_sim = ( + raw_sim.compute_psd(picks=picks).get_data(return_freqs=True)) + psd_chpi, freqs_chpi = ( + raw_chpi.compute_psd(picks=picks).get_data(return_freqs=True)) assert_array_equal(freqs_sim, freqs_chpi) freq_idx = np.sort([np.argmin(np.abs(freqs_sim - f)) diff --git a/mne/time_frequency/__init__.py b/mne/time_frequency/__init__.py index 439487d66e6..da10b51982e 100644 --- a/mne/time_frequency/__init__.py +++ b/mne/time_frequency/__init__.py @@ -10,5 +10,6 @@ from .ar import fit_iir_model_raw from .multitaper import (dpss_windows, psd_array_multitaper, tfr_array_multitaper) +from .spectrum import EpochsSpectrum, Spectrum, read_spectrum from ._stft import stft, istft, stftfreq from ._stockwell import tfr_stockwell, tfr_array_stockwell diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py index da21057a9f4..14e43cd4a53 100644 --- a/mne/time_frequency/multitaper.py +++ b/mne/time_frequency/multitaper.py @@ -4,8 +4,10 @@ # Parts of this code were copied from NiTime http://nipy.sourceforge.net/nitime import operator + import numpy as np +from ..filter import next_fast_len from ..parallel import parallel_func from ..utils import sum_squared, warn, verbose, logger, _check_option @@ -61,7 +63,7 @@ def dpss_windows(N, half_nbw, Kmax, low_bias=True, interp_from=None, from scipy import interpolate from scipy.fft import rfft, irfft from scipy.signal.windows import dpss as sp_dpss - from ..filter import next_fast_len + # This np.int32 business works around a weird Windows bug, see # gh-5039 and https://github.com/scipy/scipy/pull/8608 Kmax = np.int32(operator.index(Kmax)) diff --git a/mne/time_frequency/psd.py b/mne/time_frequency/psd.py index fe70d7fd1b3..19cfe1b003b 100644 --- a/mne/time_frequency/psd.py +++ b/mne/time_frequency/psd.py @@ -3,12 +3,16 @@ # License : BSD-3-Clause from functools import partial + import numpy as np from ..parallel import parallel_func -from ..io.pick import _picks_to_idx -from ..utils import logger, verbose, _time_mask, _check_option -from .multitaper import psd_array_multitaper +from ..utils import logger, verbose, deprecated, _check_option + +_psd_deprecation_msg = ( + 'Function psd_{0}() is deprecated; for Raw/Epochs/Evoked instances use ' + 'spectrum = instance.compute_psd(method="{0}") instead, followed by ' + 'spectrum.get_data(return_freqs=True).') # adapted from SciPy @@ -89,36 +93,6 @@ def _check_nfft(n, n_fft, n_per_seg, n_overlap): return n_fft, n_per_seg, n_overlap -def _check_psd_data(inst, tmin, tmax, picks, proj, reject_by_annotation=False): - """Check PSD data / pull arrays from inst.""" - from ..io.base import BaseRaw - from ..epochs import BaseEpochs - from ..evoked import Evoked - if not isinstance(inst, (BaseEpochs, BaseRaw, Evoked)): - raise ValueError( - f'inst must be an instance of Epochs, Raw, or Evoked. Got ' - f'{type(inst)}' - ) - - time_mask = _time_mask(inst.times, tmin, tmax, sfreq=inst.info['sfreq']) - picks = _picks_to_idx(inst.info, picks, 'data', with_ref_meg=False) - if proj: - # Copy first so it's not modified - inst = inst.copy().apply_proj() - - sfreq = inst.info['sfreq'] - if isinstance(inst, BaseRaw): - start, stop = np.where(time_mask)[0][[0, -1]] - rba = 'NaN' if reject_by_annotation else None - data = inst.get_data(picks, start, stop + 1, reject_by_annotation=rba) - elif isinstance(inst, BaseEpochs): - data = inst.get_data(picks=picks)[:, :, time_mask] - else: # Evoked - data = inst.data[picks][:, time_mask] - - return data, sfreq - - @verbose def psd_array_welch(x, sfreq, fmin=0, fmax=np.inf, n_fft=256, n_overlap=0, n_per_seg=None, n_jobs=None, average='mean', @@ -169,7 +143,9 @@ def psd_array_welch(x, sfreq, fmin=0, fmax=np.inf, n_fft=256, n_overlap=0, ----- .. versionadded:: 0.14.0 """ - _check_option('average', average, (None, 'mean', 'median')) + _check_option('average', average, (None, False, 'mean', 'median')) + if average is False: + average = None dshape = x.shape[:-1] n_times = x.shape[-1] @@ -210,6 +186,7 @@ def psd_array_welch(x, sfreq, fmin=0, fmax=np.inf, n_fft=256, n_overlap=0, return psds, freqs +@deprecated(_psd_deprecation_msg.format('welch')) @verbose def psd_welch(inst, fmin=0, fmax=np.inf, tmin=None, tmax=None, n_fft=256, n_overlap=0, n_per_seg=None, picks=None, proj=False, n_jobs=None, @@ -218,20 +195,14 @@ def psd_welch(inst, fmin=0, fmax=np.inf, tmin=None, tmax=None, n_fft=256, """Compute the power spectral density (PSD) using Welch's method. Calculates periodograms for a sliding window over the time dimension, then - averages them together for each channel/epoch. + optionally averages them together for each channel/epoch. Parameters ---------- inst : instance of Epochs or Raw or Evoked The data for PSD calculation. - fmin : float - Min frequency of interest. - fmax : float - Max frequency of interest. - tmin : float | None - Min time of interest. - tmax : float | None - Max time of interest. + %(fmin_fmax_psd)s + %(tmin_tmax_psd)s n_fft : int The length of FFT used, must be ``>= n_per_seg`` (default: 256). The segments will be zero-padded if ``n_fft > n_per_seg``. @@ -244,8 +215,7 @@ def psd_welch(inst, fmin=0, fmax=np.inf, tmin=None, tmax=None, n_fft=256, Length of each Welch segment (windowed with a Hamming window). Defaults to None, which sets n_per_seg equal to n_fft. %(picks_good_data_noref)s - proj : bool - Apply SSP projection vectors. If inst is ndarray this is not used. + %(proj_psd)s %(n_jobs)s %(reject_by_annotation_raw)s @@ -272,6 +242,8 @@ def psd_welch(inst, fmin=0, fmax=np.inf, tmin=None, tmax=None, n_fft=256, See Also -------- + Spectrum + EpochsSpectrum mne.io.Raw.plot_psd mne.Epochs.plot_psd psd_multitaper @@ -281,15 +253,14 @@ def psd_welch(inst, fmin=0, fmax=np.inf, tmin=None, tmax=None, n_fft=256, ----- .. versionadded:: 0.12.0 """ - # Prep data - data, sfreq = _check_psd_data(inst, tmin, tmax, picks, proj, - reject_by_annotation=reject_by_annotation) - return psd_array_welch(data, sfreq, fmin=fmin, fmax=fmax, n_fft=n_fft, - n_overlap=n_overlap, n_per_seg=n_per_seg, - average=average, n_jobs=n_jobs, window=window, - verbose=verbose) + spectrum = inst.compute_psd( + 'welch', fmin, fmax, tmin, tmax, picks, proj, reject_by_annotation, + n_jobs=n_jobs, verbose=verbose, n_fft=n_fft, n_overlap=n_overlap, + n_per_seg=n_per_seg, average=average, window=window) + return spectrum.get_data(return_freqs=True) +@deprecated(_psd_deprecation_msg.format('multitaper')) @verbose def psd_multitaper(inst, fmin=0, fmax=np.inf, tmin=None, tmax=None, bandwidth=None, adaptive=False, low_bias=True, @@ -306,14 +277,8 @@ def psd_multitaper(inst, fmin=0, fmax=np.inf, tmin=None, tmax=None, ---------- inst : instance of Epochs or Raw or Evoked The data for PSD calculation. - fmin : float - Min frequency of interest. - fmax : float - Max frequency of interest. - tmin : float | None - Min time of interest. - tmax : float | None - Max time of interest. + %(fmin_fmax_psd)s + %(tmin_tmax_psd)s bandwidth : float The bandwidth of the multi taper windowing function in Hz. The default value is a window half-bandwidth of 4. @@ -325,8 +290,7 @@ def psd_multitaper(inst, fmin=0, fmax=np.inf, tmin=None, tmax=None, bandwidth. %(normalization)s %(picks_good_data_noref)s - proj : bool - Apply SSP projection vectors. If inst is ndarray this is not used. + %(proj_psd)s %(n_jobs)s %(reject_by_annotation_raw)s %(verbose)s @@ -342,6 +306,8 @@ def psd_multitaper(inst, fmin=0, fmax=np.inf, tmin=None, tmax=None, See Also -------- + Spectrum + EpochsSpectrum mne.io.Raw.plot_psd mne.Epochs.plot_psd psd_array_multitaper @@ -356,10 +322,9 @@ def psd_multitaper(inst, fmin=0, fmax=np.inf, tmin=None, tmax=None, ---------- .. footbibliography:: """ - # Prep data - data, sfreq = _check_psd_data(inst, tmin, tmax, picks, proj, - reject_by_annotation=reject_by_annotation) - return psd_array_multitaper(data, sfreq, fmin=fmin, fmax=fmax, - bandwidth=bandwidth, adaptive=adaptive, - low_bias=low_bias, normalization=normalization, - n_jobs=n_jobs, verbose=verbose) + spectrum = inst.compute_psd( + 'multitaper', fmin, fmax, tmin, tmax, picks, proj, + reject_by_annotation, n_jobs=n_jobs, verbose=verbose, + bandwidth=bandwidth, adaptive=adaptive, low_bias=low_bias, + normalization=normalization) + return spectrum.get_data(return_freqs=True) diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py new file mode 100644 index 00000000000..660f886e7f3 --- /dev/null +++ b/mne/time_frequency/spectrum.py @@ -0,0 +1,1146 @@ +# -*- coding: utf-8 -*- +"""Container classes for spectral data.""" + +# Authors: Dan McCloy +# +# License: BSD-3-Clause + +from copy import deepcopy +from functools import partial +from inspect import signature + +import numpy as np + +from ..channels.channels import UpdateChannelsMixin, _get_ch_type +from ..channels.layout import _merge_ch_data +# from ..defaults import (_BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, +# _INTERPOLATION_DEFAULT) +from ..defaults import _handle_default +from ..io.meas_info import ContainsMixin +from ..io.pick import _pick_data_channels, _picks_to_idx, pick_info +from ..utils import (GetEpochsMixin, _build_data_frame, + _check_pandas_index_arguments, _check_pandas_installed, + _check_sphere, _time_mask, _validate_type, fill_doc, + logger, object_diff, verbose, warn) +from ..utils.check import (_check_fname, _check_option, _import_h5io_funcs, + _is_numeric, check_fname) +from ..utils.misc import _pl +from ..viz.topo import _plot_timeseries, _plot_timeseries_unified, _plot_topo +from ..viz.topomap import (_make_head_outlines, _prepare_topomap_plot, + plot_psds_topomap) +from ..viz.utils import _plot_psd, plt_show +from . import psd_array_multitaper, psd_array_welch +from .psd import _check_nfft + + +def _identity_function(x): + return x + + +class SpectrumMixin(): + """Mixin providing spectral plotting methods to sensor-space containers.""" + + @verbose + def plot_psd(self, fmin=0, fmax=np.inf, tmin=None, tmax=None, picks=None, + proj=False, reject_by_annotation=True, *, method='auto', + average=False, dB=True, estimate='auto', xscale='linear', + area_mode='std', area_alpha=0.33, color='black', + line_alpha=None, spatial_colors=True, sphere=None, + exclude='bads', ax=None, show=True, n_jobs=1, verbose=None, + **method_kw): + """%(plot_psd_doc)s. + + Parameters + ---------- + %(fmin_fmax_psd)s + %(tmin_tmax_psd)s + %(picks_good_data_noref)s + %(proj_psd)s + %(reject_by_annotation_psd)s + %(method_plot_psd_auto)s + %(average_plot_psd)s + %(dB_plot_psd)s + %(estimate_plot_psd)s + %(xscale_plot_psd)s + %(area_mode_plot_psd)s + %(area_alpha_plot_psd)s + %(color_plot_psd)s + %(line_alpha_plot_psd)s + %(spatial_colors_psd)s + %(sphere_topomap_auto)s + + .. versionadded:: 0.22.0 + exclude : list of str | 'bads' + Channels names to exclude from being shown. If 'bads', the bad + channels are excluded. Pass an empty list to plot all channels + (including channels marked "bad", if any). + + .. versionadded:: 0.24.0 + %(ax_plot_psd)s + %(show)s + %(n_jobs)s + %(verbose)s + %(method_kw_psd)s + + Returns + ------- + fig : instance of Figure + Figure with frequency spectra of the data channels. + + Notes + ----- + %(notes_plot_psd_meth)s + """ + from ..io import BaseRaw + + self._set_legacy_nfft_default(tmin, tmax, method, method_kw) + # triage reject_by_annotation + rba = dict() + if isinstance(self, BaseRaw): + rba = dict(reject_by_annotation=reject_by_annotation) + + spectrum = self.compute_psd( + method=method, fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, + picks=picks, proj=proj, n_jobs=n_jobs, verbose=verbose, **rba, + **method_kw) + + # translate kwargs + amplitude = 'auto' if estimate == 'auto' else (estimate == 'amplitude') + ci = 'sd' if area_mode == 'std' else area_mode + # ↓ here picks="all" because we've already restricted the `info` to + # ↓ have only `picks` channels + fig = spectrum.plot( + picks='all', average=average, dB=dB, amplitude=amplitude, + xscale=xscale, ci=ci, ci_alpha=area_alpha, color=color, + alpha=line_alpha, spatial_colors=spatial_colors, sphere=sphere, + exclude=exclude, axes=ax, show=show) + return fig + + @verbose + def plot_psd_topo(self, tmin=None, tmax=None, fmin=0, fmax=100, proj=False, + *, method='auto', dB=True, layout=None, color='w', + fig_facecolor='k', axis_facecolor='k', axes=None, + block=False, show=True, n_jobs=None, verbose=None, + **method_kw): + """Plot power spectral density, separately for each channel. + + Parameters + ---------- + %(tmin_tmax_psd)s + %(fmin_fmax_psd_topo)s + %(proj_psd)s + %(method_plot_psd_auto)s + %(dB_spectrum_plot_topo)s + %(layout_spectrum_plot_topo)s + %(color_spectrum_plot_topo)s + %(fig_facecolor)s + %(axis_facecolor)s + %(axes_spectrum_plot_topo)s + %(block)s + %(show)s + %(n_jobs)s + %(verbose)s + %(method_kw_psd)s Defaults to ``dict(n_fft=2048)``. + + Returns + ------- + fig : instance of matplotlib.figure.Figure + Figure distributing one image per channel across sensor topography. + """ + self._set_legacy_nfft_default(tmin, tmax, method, method_kw) + + spectrum = self.compute_psd( + method=method, fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, + proj=proj, n_jobs=n_jobs, verbose=verbose, **method_kw) + + return spectrum.plot_topo( + dB=dB, layout=layout, color=color, fig_facecolor=fig_facecolor, + axis_facecolor=axis_facecolor, axes=axes, block=block, show=show) + + @verbose + def plot_psd_topomap(self, bands=None, tmin=None, tmax=None, proj=False, + method='auto', ch_type=None, *, normalize=False, + agg_fun=None, dB=False, # sensors=True, + # show_names=False, mask=None, mask_params=None, + # contours=6, + outlines='head', sphere=None, + # image_interp=_INTERPOLATION_DEFAULT, + # extrapolate=_EXTRAPOLATE_DEFAULT, + # border=_BORDER_DEFAULT, res=64, size=1, + cmap=None, vlim=(None, None), # colorbar=True, + cbar_fmt='auto', units=None, + axes=None, show=True, n_jobs=None, verbose=None, + **method_kw): + """Plot scalp topography of PSD for chosen frequency bands. + + Parameters + ---------- + %(bands_psd_topo)s + %(tmin_tmax_psd)s + %(proj_psd)s + %(method_plot_psd_auto)s + %(ch_type_psd_topomap)s + %(normalize_psd_topo)s + %(agg_fun_psd_topo)s + %(dB_plot_topomap)s + %(outlines_topomap)s + %(sphere_topomap_auto)s + %(cmap_psd_topo)s + %(vlim_psd_topo_joint)s + %(cbar_fmt_psd_topo)s + %(units_topomap)s + %(axes_plot_topomap)s + %(show)s + %(n_jobs)s + %(verbose)s + %(method_kw_psd)s + + Returns + ------- + fig : instance of Figure + Figure showing one scalp topography per frequency band. + """ + # add after dB + # %(sensors_topomap)s + # %(show_names_topomap)s + # %(mask_evoked_topomap)s + # %(mask_params_topomap)s + # %(contours_topomap)s + # add after sphere + # %(image_interp_topomap)s + # %(extrapolate_topomap)s + # %(border_topomap)s + # %(res_topomap)s + # %(size_topomap)s + # add after vlim + # %(colorbar_topomap)s + spectrum = self.compute_psd( + method=method, tmin=tmin, tmax=tmax, proj=proj, + n_jobs=n_jobs, verbose=verbose, **method_kw) + + fig = spectrum.plot_topomap( + bands=bands, ch_type=ch_type, normalize=normalize, agg_fun=agg_fun, + dB=dB, # sensors=sensors, show_names=show_names, mask=mask, + # mask_params=mask_params, contours=contours, + outlines=outlines, sphere=sphere, + # image_interp=image_interp, extrapolate=extrapolate, + # border=border, res=res, size=size, + cmap=cmap, vlim=vlim, # colorbar=colorbar, + cbar_fmt=cbar_fmt, units=units, axes=axes, show=show) + return fig + + def _set_legacy_nfft_default(self, tmin, tmax, method, method_kw): + """Update method_kw with legacy n_fft default for plot_psd[_topo](). + + This method returns ``None`` and has a side effect of (maybe) updating + the ``method_kw`` dict. + """ + if method == 'welch' and method_kw.get('n_fft', None) is None: + tm = _time_mask(self.times, tmin, tmax, sfreq=self.info['sfreq']) + method_kw['n_fft'] = min(np.sum(tm), 2048) + + +class BaseSpectrum(ContainsMixin, UpdateChannelsMixin): + """Base class for Spectrum and EpochsSpectrum.""" + + def __init__(self, inst, method, fmin, fmax, tmin, tmax, picks, + proj, *, n_jobs, verbose, **method_kw): + # arg checking + self._sfreq = inst.info['sfreq'] + if np.isfinite(fmax) and (fmax > self.sfreq / 2): + raise ValueError( + f'Requested fmax ({fmax} Hz) must not exceed ½ the sampling ' + f'frequency of the data ({0.5 * inst.info["sfreq"]} Hz).') + # method + self._inst_type = type(inst) + if method == 'auto': + method = ('welch' if self._get_instance_type_string() == 'Raw' + else 'multitaper') + _check_option('method', method, ('welch', 'multitaper')) + + # triage method and kwargs. partial() doesn't check validity of kwargs, + # so we do it manually to save compute time if any are invalid. + psd_funcs = dict(welch=psd_array_welch, + multitaper=psd_array_multitaper) + invalid_ix = np.in1d(list(method_kw), + list(signature(psd_funcs[method]).parameters), + invert=True) + if invalid_ix.any(): + invalid_kw = np.array(list(method_kw))[invalid_ix].tolist() + s = _pl(invalid_kw) + raise TypeError( + f'Got unexpected keyword argument{s} {", ".join(invalid_kw)} ' + f'for PSD method "{method}".') + self._psd_func = partial(psd_funcs[method], **method_kw) + + # apply proj if desired + if proj: + inst = inst.copy().apply_proj() + + # prep times and picks + self._time_mask = _time_mask(inst.times, tmin, tmax, sfreq=self.sfreq) + self._picks = _picks_to_idx(inst.info, picks, 'data', + with_ref_meg=False) + + # add the info object. bads and non-data channels were dropped by + # _picks_to_idx() so we update the info accordingly: + self.info = pick_info(inst.info, sel=self._picks, copy=True) + + # assign some attributes + self.preload = True # needed for __getitem__, doesn't mean anything + self._method = method + # self._dims may also get updated by child classes + self._dims = ('channel', 'freq',) + if method_kw.get('average', '') in (None, False): + self._dims += ('segment',) + if method_kw.get('output', '') == 'complex': + self._dims = self._dims[:-1] + ('taper',) + self._dims[-1:] + # record data type (for repr and html_repr) + self._data_type = ('Fourier Coefficients' if 'taper' in self._dims + else 'Power Spectrum') + + def __eq__(self, other): + """Test equivalence of two Spectrum instances.""" + return object_diff(vars(self), vars(other)) == '' + + def __getstate__(self): + """Prepare object for serialization.""" + inst_type_str = self._get_instance_type_string() + out = dict(method=self.method, + data=self._data, + sfreq=self.sfreq, + dims=self._dims, + freqs=self.freqs, + inst_type_str=inst_type_str, + data_type=self._data_type, + info=self.info) + return out + + def __setstate__(self, state): + """Unpack from serialized format.""" + from .. import Epochs, Evoked, Info + from ..io import Raw + + self._method = state['method'] + self._data = state['data'] + self._freqs = state['freqs'] + self._dims = state['dims'] + self._sfreq = state['sfreq'] + self.info = Info(**state['info']) + self._data_type = state['data_type'] + self.preload = True + # instance type + inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked) + self._inst_type = inst_types[state['inst_type_str']] + + def __repr__(self): + """Build string representation of the Spectrum object.""" + inst_type_str = self._get_instance_type_string() + # shape & dimension names + dims = ' × '.join( + [f'{dim[0]} {dim[1]}s' for dim in zip(self.shape, self._dims)]) + freq_range = f'{self.freqs[0]:0.1f}-{self.freqs[-1]:0.1f} Hz' + return (f'<{self._data_type} (from {inst_type_str}, ' + f'{self.method} method) | {dims}, {freq_range}>') + + def _repr_html_(self, caption=None): + """Build HTML representation of the Spectrum object.""" + from ..html_templates import repr_templates_env + + inst_type_str = self._get_instance_type_string() + units = [f'{ch_type}: {unit}' + for ch_type, unit in self.units().items()] + t = repr_templates_env.get_template('spectrum.html.jinja') + t = t.render(spectrum=self, inst_type=inst_type_str, units=units) + return t + + def _check_values(self): + """Check PSD results for correct shape and bad values.""" + assert len(self._dims) == self._data.ndim + assert self._data.shape == self._shape + # negative values OK if the spectrum is really fourier coefficients + if 'taper' in self._dims: + return + # TODO: should this be more fine-grained (report "chan X in epoch Y")? + ch_dim = self._dims.index('channel') + dims = np.arange(self._data.ndim).tolist() + dims.pop(ch_dim) + # take min() across all but the channel axis + bad_value = self._data.min(axis=tuple(dims)) <= 0 + if bad_value.any(): + chs = np.array(self.ch_names)[bad_value].tolist() + s = _pl(bad_value.sum()) + warn(f'Zero value in spectrum for channel{s} {", ".join(chs)}', + UserWarning) + + def _compute_spectra(self, data, fmin, fmax, n_jobs, method_kw, verbose): + # make the spectra + result = self._psd_func( + data, self.sfreq, fmin=fmin, fmax=fmax, n_jobs=n_jobs, + verbose=verbose) + # assign ._data (handling unaggregated multitaper output) + if method_kw.get('output', '') == 'complex': + fourier_coefs, freqs, weights = result + self._data = fourier_coefs + self._mt_weights = weights + else: + psds, freqs = result + self._data = psds + # assign properties (._data already assigned above) + self._freqs = freqs + # this is *expected* shape, it gets asserted later in _check_values() + # (and then deleted afterwards) + self._shape = (len(self.ch_names), len(self.freqs)) + # append n_welch_segments + if method_kw.get('average', '') in (None, False): + n_welch_segments = _compute_n_welch_segments(data.shape[-1], + method_kw) + self._shape += (n_welch_segments,) + # insert n_tapers + if method_kw.get('output', '') == 'complex': + self._shape = ( + self._shape[:-1] + (self._mt_weights.size,) + self._shape[-1:]) + # we don't need these anymore, and they make save/load harder + del self._picks + del self._psd_func + del self._time_mask + + def _format_units(self, unit, latex, power=True): + """Format the measurement units nicely.""" + unit = f'({unit})' if '/' in unit else unit + if power: + denom = 'Hz' + exp = r'^{2}' if latex else '²' + else: + denom = r'\sqrt{Hz}' if latex else '√(Hz)' + exp = '' + pre, post = (r'$\mathrm{', r'}$') if latex else ('', '') + return f'{pre}{unit}{exp}/{denom}{post}' + + def _get_instance_type_string(self): + """Get string representation of the originating instance type.""" + from .. import BaseEpochs, Evoked, EvokedArray + from ..io import BaseRaw + + parent_classes = self._inst_type.__bases__ + if BaseRaw in parent_classes: + inst_type_str = 'Raw' + elif BaseEpochs in parent_classes: + inst_type_str = 'Epochs' + elif self._inst_type in (Evoked, EvokedArray): + inst_type_str = 'Evoked' + else: + raise RuntimeError( + f'Unknown instance type {self._inst_type} in Spectrum') + return inst_type_str + + @property + def _detrend_picks(self): + """Provide compatibility with __iter__.""" + return list() + + @property + def ch_names(self): + return self.info['ch_names'] + + @property + def freqs(self): + return self._freqs + + @property + def method(self): + return self._method + + @property + def sfreq(self): + return self._sfreq + + @property + def shape(self): + return self._data.shape + + def copy(self): + """Return copy of the Spectrum instance. + + Returns + ------- + spectrum : instance of Spectrum + A copy of the object. + """ + return deepcopy(self) + + @fill_doc + def get_data(self, picks=None, exclude='bads', fmin=0, fmax=np.inf, + return_freqs=False): + """Get spectrum data in NumPy array format. + + Parameters + ---------- + %(picks_good_data_noref)s + %(exclude_spectrum_get_data)s + %(fmin_fmax_psd)s + return_freqs : bool + Whether to return the frequency bin values for the requested + frequency range. Default is ``True``. + + Returns + ------- + data : array + The requested data in a NumPy array. + freqs : array + The frequency values for the requested range. Only returned if + ``return_freqs`` is ``True``. + """ + picks = _picks_to_idx(self.info, picks, 'data_or_ica', exclude=exclude, + with_ref_meg=False) + fmin_idx = np.searchsorted(self.freqs, fmin) + fmax_idx = np.searchsorted(self.freqs, fmax, side='right') + freq_picks = np.arange(fmin_idx, fmax_idx) + freq_axis = self._dims.index('freq') + chan_axis = self._dims.index('channel') + # normally there's a risk of np.take reducing array dimension if there + # were only one channel or frequency selected, but `_picks_to_idx` + # always returns an array of picks, and np.arange always returns an + # array of freq bin indices, so we're safe; the result will always be + # 2D. + data = self._data.take(picks, chan_axis).take(freq_picks, freq_axis) + if return_freqs: + freqs = self._freqs[fmin_idx:fmax_idx] + return (data, freqs) + return data + + @fill_doc + def plot(self, *, picks=None, average=False, dB=True, amplitude='auto', + xscale='linear', ci='sd', ci_alpha=0.3, color='black', alpha=None, + spatial_colors=True, sphere=None, exclude='bads', axes=None, + show=True): + """%(plot_psd_doc)s. + + Parameters + ---------- + %(picks_good_data_noref)s + average : bool + Whether to average across channels before plotting. If ``True``, + interactive plotting of scalp topography is disabled, and + parameters ``ci`` and ``ci_alpha`` control the style of the + confidence band around the mean. Default is ``False``. + %(dB_spectrum_plot)s + amplitude : bool | 'auto' + Whether to plot an amplitude spectrum (``True``) or power spectrum + (``False``). If ``'auto'``, will plot a power spectrum when + ``dB=True`` and an amplitude spectrum otherwise. Default is + ``'auto'``. + %(xscale_plot_psd)s + ci : float | 'sd' | 'range' | None + Type of confidence band drawn around the mean when + ``average=True``. If ``'sd'`` the band spans ±1 standard deviation + across channels. If ``'range'`` the band spans the range across + channels at each frequency. If a :class:`float`, it indicates the + (bootstrapped) confidence interval to display, and must satisfy + ``0 < ci <= 100``. If ``None``, no band is drawn. Default is + ``sd``. + ci_alpha : float + Opacity of the confidence band. Must satisfy + ``0 <= ci_alpha <= 1``. Default is 0.3. + %(color_plot_psd)s + alpha : float | None + Opacity of the spectrum line(s). If :class:`float`, must satisfy + ``0 <= alpha <= 1``. If ``None``, opacity will be ``1`` when + ``average=True`` and ``0.1`` when ``average=False``. Default is + ``None``. + %(spatial_colors_psd)s + %(sphere_topomap_auto)s + %(exclude_spectrum_plot)s + %(axes_plot_topomap)s + %(show)s + + Returns + ------- + fig : instance of matplotlib.figure.Figure + Figure with spectra plotted in separate subplots for each channel + type. + """ + from ..viz._mpl_figure import _line_figure, _split_picks_by_type + from .multitaper import _psd_from_mt + + # arg checking + ci = _check_ci(ci) + _check_option('xscale', xscale, ('log', 'linear')) + sphere = _check_sphere(sphere, self.info) + # defaults + scalings = _handle_default('scalings', None) + titles = _handle_default('titles', None) + units = _handle_default('units', None) + if amplitude == 'auto': + estimate = 'power' if dB else 'amplitude' + else: # amplitude is boolean + estimate = 'amplitude' if amplitude else 'power' + # split picks by channel type + picks = _picks_to_idx(self.info, picks, 'data', with_ref_meg=False) + (picks_list, units_list, scalings_list, titles_list + ) = _split_picks_by_type(self, picks, units, scalings, titles) + # handle unaggregated multitaper + if hasattr(self, '_mt_weights'): + logger.info('Aggregating multitaper estimates before plotting...') + _f = partial(_psd_from_mt, weights=self._mt_weights) + # handle unaggregated Welch + elif 'segment' in self._dims: + logger.info( + 'Aggregating Welch estimates (median) before plotting...') + seg_axis = self._dims.index('segment') + _f = partial(np.nanmedian, axis=seg_axis) + else: # "normal" cases + _f = _identity_function + ch_axis = self._dims.index('channel') + psd_list = [_f(self._data.take(_p, axis=ch_axis)) for _p in picks_list] + # handle epochs + if 'epoch' in self._dims: + # XXX TODO FIXME decide how to properly aggregate across repeated + # measures (epochs) and non-repeated but correlated measures + # (channels) when calculating stddev or a CI. For across-channel + # aggregation, doi:10.1007/s10162-012-0321-8 used hotellings T**2 + # with a correction factor that estimated data rank using monte + # carlo simulations; seems like we could use our own data rank + # estimation methods to similar effect. Their exact approach used + # complex spectra though, here we've already converted to power; + # not sure if that makes an important difference? Anyway that + # aggregation would need to happen in the _plot_psd function + # though, not here... for now we just average like we always did. + + # only log message if averaging will actually have an effect + if self._data.shape[0] > 1: + logger.info('Averaging across epochs...') + # epoch axis should always be the first axis + psd_list = [_p.mean(axis=0) for _p in psd_list] + # initialize figure + fig, axes = _line_figure(self, axes, picks=picks) + # don't add ylabels & titles if figure has unexpected number of axes + make_label = len(axes) == len(fig.axes) + # Plot Frequency [Hz] xlabel only on the last axis + xlabels_list = [False] * (len(axes) - 1) + [True] + # plot + _plot_psd(self, fig, self.freqs, psd_list, picks_list, titles_list, + units_list, scalings_list, axes, make_label, color, + area_mode=ci, area_alpha=ci_alpha, dB=dB, estimate=estimate, + average=average, spatial_colors=spatial_colors, + xscale=xscale, line_alpha=alpha, + sphere=sphere, xlabels_list=xlabels_list) + fig.subplots_adjust(hspace=0.3) + plt_show(show, fig) + return fig + + @fill_doc + def plot_topo(self, *, dB=True, layout=None, color='w', + fig_facecolor='k', axis_facecolor='k', axes=None, + block=False, show=True): + """Plot power spectral density, separately for each channel. + + Parameters + ---------- + %(dB_spectrum_plot_topo)s + %(layout_spectrum_plot_topo)s + %(color_spectrum_plot_topo)s + %(fig_facecolor)s + %(axis_facecolor)s + %(axes_spectrum_plot_topo)s + %(block)s + %(show)s + + Returns + ------- + fig : instance of matplotlib.figure.Figure + Figure distributing one image per channel across sensor topography. + """ + if layout is None: + from ..channels.layout import find_layout + layout = find_layout(self.info) + + psds, freqs = self.get_data(return_freqs=True) + if dB: + psds = 10 * np.log10(psds) + y_label = 'dB' + else: + y_label = 'Power' + show_func = partial( + _plot_timeseries_unified, data=[psds], color=color, times=[freqs]) + click_func = partial( + _plot_timeseries, data=[psds], color=color, times=[freqs]) + picks = _pick_data_channels(self.info) + info = pick_info(self.info, picks) + fig = _plot_topo( + info, times=freqs, show_func=show_func, click_func=click_func, + layout=layout, axis_facecolor=axis_facecolor, + fig_facecolor=fig_facecolor, x_label='Frequency (Hz)', + unified=True, y_label=y_label, axes=axes) + plt_show(show, block=block) + return fig + + @fill_doc + def plot_topomap(self, bands=None, ch_type=None, *, normalize=False, + agg_fun=None, dB=False, # sensors=True, show_names=False, + # mask=None, mask_params=None, contours=6, + outlines='head', + sphere=None, # image_interp=_INTERPOLATION_DEFAULT, + # extrapolate=_EXTRAPOLATE_DEFAULT, + # border=_BORDER_DEFAULT, res=64, size=1, + cmap=None, vlim=(None, None), # colorbar=True, + cbar_fmt='auto', units=None, axes=None, + show=True): + """Plot scalp topography of PSD for chosen frequency bands. + + Parameters + ---------- + %(bands_psd_topo)s + %(ch_type_psd_topomap)s + %(normalize_psd_topo)s + %(agg_fun_psd_topo)s + %(dB_plot_topomap)s + %(outlines_topomap)s + %(sphere_topomap_auto)s + %(cmap_psd_topo)s + %(vlim_psd_topo_joint)s + %(cbar_fmt_psd_topo)s + %(units_topomap)s + %(axes_plot_topomap)s + %(show)s + + Returns + ------- + fig : instance of Figure + Figure showing one scalp topography per frequency band. + """ + # add after dB + # %(sensors_topomap)s + # %(show_names_topomap)s + # %(mask_evoked_topomap)s + # %(mask_params_topomap)s + # %(contours_topomap)s + # add after sphere + # %(image_interp_topomap)s + # %(extrapolate_topomap)s + # %(border_topomap)s + # %(res_topomap)s + # %(size_topomap)s + # add after vlim + # %(colorbar_topomap)s + ch_type = _get_ch_type(self, ch_type) + if units is None: + units = _handle_default('units', None) + unit = units[ch_type] if hasattr(units, 'keys') else units + scalings = _handle_default('scalings', None) + scaling = scalings[ch_type] + + picks, pos, merge_channels, names, ch_type, sphere, clip_origin = \ + _prepare_topomap_plot(self, ch_type, sphere=sphere) + outlines = _make_head_outlines(sphere, pos, outlines, clip_origin) + + psds, freqs = self.get_data(picks=picks, return_freqs=True) + if 'epoch' in self._dims: + psds = np.mean(psds, axis=self._dims.index('epoch')) + psds *= scaling**2 + + if merge_channels: + psds, names = _merge_ch_data(psds, ch_type, names, method='mean') + + return plot_psds_topomap( + psds=psds, freqs=freqs, pos=pos, bands=bands, ch_type=ch_type, + normalize=normalize, agg_fun=agg_fun, dB=dB, # sensors=sensors, + # show_names=show_names, mask=mask, mask_params=mask_params, + # contours=contours, + outlines=outlines, sphere=sphere, + # image_interp=image_interp, extrapolate=extrapolate, + # border=border, res=res, size=size, + cmap=cmap, vlim=vlim, # colorbar=colorbar, + cbar_fmt=cbar_fmt, unit=unit, axes=axes, show=show) + + @verbose + def save(self, fname, *, overwrite=False, verbose=None): + """Save spectrum data to disk (in HDF5 format). + + Parameters + ---------- + fname : path-like + Path of file to save to. + %(overwrite)s + %(verbose)s + + See Also + -------- + mne.time_frequency.read_spectrum + """ + _, write_hdf5 = _import_h5io_funcs() + check_fname(fname, 'spectrum', ('.h5', '.hdf5')) + fname = _check_fname(fname, overwrite=overwrite, verbose=verbose) + out = self.__getstate__() + write_hdf5(fname, out, overwrite=overwrite, title='mnepython') + + @verbose + def to_data_frame(self, picks=None, index=None, copy=True, + long_format=False, *, verbose=None): + """Export data in tabular structure as a pandas DataFrame. + + Channels are converted to columns in the DataFrame. By default, + an additional column "frequency" is added, unless ``index='freq'`` + (in which case frequency values form the DataFrame's index). + + Parameters + ---------- + %(picks_all)s + index : str | list of str | None + Kind of index to use for the DataFrame. If ``None``, a sequential + integer index (:class:`pandas.RangeIndex`) will be used. If a + :class:`str`, a :class:`pandas.Index`, :class:`pandas.Int64Index`, + or :class:`pandas.Float64Index` will be used (see Notes). If a list + of two or more string values, a :class:`pandas.MultiIndex` will be + used. Defaults to ``None``. + %(copy_df)s + %(long_format_df_spe)s + %(verbose)s + + Returns + ------- + %(df_return)s + + Notes + ----- + Valid values for ``index`` depend on whether the Spectrum was created + from continuous data (:class:`~mne.io.Raw`, :class:`~mne.Evoked`) or + discontinuous data (:class:`~mne.Epochs`). For continuous data, only + ``None`` or ``'freq'`` is supported. For discontinuous data, additional + valid values are ``'epoch'`` and ``'condition'``, or a :class:`list` + comprising some of the valid string values (e.g., + ``['freq', 'epoch']``). + """ + # check pandas once here, instead of in each private utils function + pd = _check_pandas_installed() # noqa + # triage for Epoch-derived or unaggregated spectra + from_epo = self._get_instance_type_string() == 'Epochs' + unagg_welch = 'segment' in self._dims + unagg_mt = 'taper' in self._dims + # arg checking + valid_index_args = ['freq'] + if from_epo: + valid_index_args += ['epoch', 'condition'] + index = _check_pandas_index_arguments(index, valid_index_args) + # get data + picks = _picks_to_idx(self.info, picks, 'all', exclude=()) + data = self.get_data(picks) + if copy: + data = data.copy() + # reshape + if unagg_mt: + data = np.moveaxis(data, self._dims.index('freq'), -2) + if from_epo: + n_epochs, n_picks, n_freqs = data.shape[:3] + else: + n_epochs, n_picks, n_freqs = (1,) + data.shape[:2] + n_segs = data.shape[-1] if unagg_mt or unagg_welch else 1 + data = np.moveaxis(data, self._dims.index('channel'), -1) + # at this point, should be ([epoch], freq, [segment/taper], channel) + data = data.reshape(n_epochs * n_freqs * n_segs, n_picks) + # prepare extra columns / multiindex + mindex = list() + default_index = list() + if from_epo: + rev_event_id = {v: k for k, v in self.event_id.items()} + _conds = [rev_event_id[k] for k in self.events[:, 2]] + conditions = np.repeat(_conds, n_freqs * n_segs) + epoch_nums = np.repeat(self.selection, n_freqs * n_segs) + mindex.extend([('condition', conditions), ('epoch', epoch_nums)]) + default_index.extend(['condition', 'epoch']) + freqs = np.tile(np.repeat(self.freqs, n_segs), n_epochs) + mindex.append(('freq', freqs)) + default_index.append('freq') + if unagg_mt or unagg_welch: + name = 'taper' if unagg_mt else 'segment' + seg_nums = np.tile(np.arange(n_segs), n_epochs * n_freqs) + mindex.append((name, seg_nums)) + default_index.append(name) + # build DataFrame + df = _build_data_frame(self, data, picks, long_format, mindex, index, + default_index=default_index) + return df + + def units(self, latex=False): + """Get the spectrum units for each channel type. + + Parameters + ---------- + latex : bool + Whether to format the unit strings as LaTeX. Default is ``False``. + + Returns + ------- + units : dict + Mapping from channel type to a string representation of the units + for that channel type. + """ + units = _handle_default('si_units', None) + power = not hasattr(self, '_mt_weights') + return {ch_type: self._format_units(units[ch_type], power=power, + latex=latex) + for ch_type in sorted(self.get_channel_types(unique=True))} + + +@fill_doc +class Spectrum(BaseSpectrum): + """Data object for spectral representations of continuous data. + + .. warning:: The preferred means of creating Spectrum objects from + continuous or averaged data is via the instance methods + :meth:`mne.io.Raw.compute_psd` or + :meth:`mne.Evoked.compute_psd`. Direct class instantiation + is not supported. + + Parameters + ---------- + inst : instance of Raw or Evoked + The data from which to compute the frequency spectrum. + %(method_psd_auto)s + ``'auto'`` (default) uses Welch's method for continuous data + and multitaper for :class:`~mne.Evoked` data. + %(fmin_fmax_psd)s + %(tmin_tmax_psd)s + %(picks_good_data_noref)s + %(proj_psd)s + %(reject_by_annotation_psd)s + %(n_jobs)s + %(verbose)s + %(method_kw_psd)s + + Attributes + ---------- + ch_names : list + The channel names. + freqs : array + Frequencies at which the amplitude, power, or fourier coefficients + have been computed. + %(info_not_none)s + method : str + The method used to compute the spectrum ('welch' or 'multitaper'). + + See Also + -------- + EpochsSpectrum + mne.io.Raw.compute_psd + mne.Epochs.compute_psd + mne.Evoked.compute_psd + """ + + def __init__(self, inst, method, fmin, fmax, tmin, tmax, picks, + proj, reject_by_annotation, *, n_jobs, verbose, **method_kw): + from ..io import BaseRaw + + # triage reading from file + if isinstance(inst, dict): + self.__setstate__(inst) + return + # do the basic setup + super().__init__(inst, method, fmin, fmax, tmin, tmax, picks, proj, + n_jobs=n_jobs, verbose=verbose, **method_kw) + # get just the data we want + if isinstance(inst, BaseRaw): + start, stop = np.where(self._time_mask)[0][[0, -1]] + rba = 'NaN' if reject_by_annotation else None + data = inst.get_data(self._picks, start, stop + 1, + reject_by_annotation=rba) + else: # Evoked + data = inst.data[self._picks][:, self._time_mask] + # compute the spectra + self._compute_spectra(data, fmin, fmax, n_jobs, method_kw, verbose) + # check for correct shape and bad values + self._check_values() + del self._shape + + def __getitem__(self, item): + """Get Spectrum data. + + Parameters + ---------- + item : int | slice | array-like + Indexing is similar to a :class:`NumPy array`; see + Notes. + + Returns + ------- + %(getitem_spectrum_return)s + + Notes + ----- + Integer-, list-, and slice-based indexing is possible: + + - ``spectrum[0]`` gives all frequency bins in the first channel + - ``spectrum[:3]`` gives all frequency bins in the first 3 channels + - ``spectrum[[0, 2], 5]`` gives the value in the sixth frequency bin of + the first and third channels + - ``spectrum[(4, 7)]`` is the same as ``spectrum[4, 7]``. + + .. note:: + + Unlike :class:`~mne.io.Raw` objects (which returns a tuple of the + requested data values and the corresponding times), accessing + :class:`~mne.time_frequency.Spectrum` values via subscript does + **not** return the corresponding frequency bin values. If you need + them, use ``spectrum.freqs[freq_indices]``. + """ + from ..io import BaseRaw + self._parse_get_set_params = partial( + BaseRaw._parse_get_set_params, self) + return BaseRaw._getitem(self, item, return_times=False) + + +@fill_doc +class EpochsSpectrum(BaseSpectrum, GetEpochsMixin): + """Data object for spectral representations of epoched data. + + .. warning:: The preferred means of creating Spectrum objects from Epochs + is via the instance method :meth:`mne.Epochs.compute_psd`. + Direct class instantiation is not supported. + + Parameters + ---------- + inst : instance of Epochs + The data from which to compute the frequency spectrum. + %(method_psd)s + %(fmin_fmax_psd)s + %(tmin_tmax_psd)s + %(picks_good_data_noref)s + %(proj_psd)s + %(n_jobs)s + %(verbose)s + %(method_kw_psd)s + + Attributes + ---------- + ch_names : list + The channel names. + freqs : array + Frequencies at which the amplitude, power, or fourier coefficients + have been computed. + %(info_not_none)s + method : str + The method used to compute the spectrum ('welch' or 'multitaper'). + + See Also + -------- + Spectrum + mne.io.Raw.compute_psd + mne.Epochs.compute_psd + mne.Evoked.compute_psd + """ + + def __init__(self, inst, method, fmin, fmax, tmin, tmax, picks, proj, *, + n_jobs, verbose, **method_kw): + # triage reading from file + if isinstance(inst, dict): + self.__setstate__(inst) + return + # do the basic setup + super().__init__(inst, method, fmin, fmax, tmin, tmax, picks, proj, + n_jobs=n_jobs, verbose=verbose, **method_kw) + # get just the data we want + data = inst.get_data(picks=self._picks)[:, :, self._time_mask] + # compute the spectra + self._compute_spectra(data, fmin, fmax, n_jobs, method_kw, verbose) + self._dims = ('epoch',) + self._dims + self._shape = (len(inst),) + self._shape + # check for correct shape and bad values + self._check_values() + del self._shape + # we need these for to_data_frame() + self.event_id = inst.event_id.copy() + self.events = inst.events.copy() + self.selection = inst.selection.copy() + # we need these for __getitem__() + self.drop_log = deepcopy(inst.drop_log) + self._metadata = inst.metadata + + def __getitem__(self, item): + """Subselect epochs from an EpochsSpectrum. + + Parameters + ---------- + item : int | slice | array-like | str + Access options are the same as for :class:`~mne.Epochs` objects, + see the docstring of :meth:`mne.Epochs.__getitem__` for + explanation. + + Returns + ------- + %(getitem_epochspectrum_return)s + """ + return super().__getitem__(item) + + def __getstate__(self): + """Prepare object for serialization.""" + out = super().__getstate__() + out.update(metadata=self._metadata, + drop_log=self.drop_log, + event_id=self.event_id, + events=self.events, + selection=self.selection) + return out + + def __setstate__(self, state): + """Unpack from serialized format.""" + super().__setstate__(state) + self._metadata = state['metadata'] + self.drop_log = state['drop_log'] + self.event_id = state['event_id'] + self.events = state['events'] + self.selection = state['selection'] + + +def read_spectrum(fname): + """Load a :class:`mne.time_frequency.Spectrum` object from disk. + + Parameters + ---------- + fname : path-like + Path to a spectrum file in HDF5 format. + + Returns + ------- + spectrum : instance of Spectrum + The loaded Spectrum object. + + See Also + -------- + mne.time_frequency.Spectrum.save + """ + read_hdf5, _ = _import_h5io_funcs() + _validate_type(fname, 'path-like', 'fname') + fname = _check_fname(fname=fname, overwrite='read', must_exist=False) + # read it in + hdf5_dict = read_hdf5(fname, title='mnepython') + defaults = dict(method=None, fmin=None, fmax=None, tmin=None, tmax=None, + picks=None, proj=None, reject_by_annotation=None, + n_jobs=None, verbose=None) + Klass = (EpochsSpectrum if hdf5_dict['inst_type_str'] == 'Epochs' + else Spectrum) + return Klass(hdf5_dict, **defaults) + + +def _check_ci(ci): + ci = 'sd' if ci == 'std' else ci # be forgiving + if _is_numeric(ci): + if not (0 < ci <= 100): + raise ValueError(f'ci must satisfy 0 < ci <= 100, got {ci}') + ci /= 100. + else: + _check_option('ci', ci, [None, 'sd', 'range']) + return ci + + +def _compute_n_welch_segments(n_times, method_kw): + # get default values from psd_array_welch + _defaults = dict() + for param in ('n_fft', 'n_per_seg', 'n_overlap'): + _defaults[param] = signature(psd_array_welch).parameters[param].default + # override defaults with user-specified values + for key, val in _defaults.items(): + _defaults.update({key: method_kw.get(key, val)}) + # sanity check values / replace `None`s with real numbers + n_fft, n_per_seg, n_overlap = _check_nfft(n_times, **_defaults) + # compute expected number of segments + return n_times // (n_per_seg - n_overlap) diff --git a/mne/time_frequency/tests/test_csd.py b/mne/time_frequency/tests/test_csd.py index 2de5e3cb0d8..71448faf061 100644 --- a/mne/time_frequency/tests/test_csd.py +++ b/mne/time_frequency/tests/test_csd.py @@ -14,7 +14,7 @@ csd_array_multitaper, csd_array_morlet, tfr_morlet, csd_tfr, CrossSpectralDensity, read_csd, - pick_channels_csd, psd_multitaper) + pick_channels_csd) from mne.time_frequency.csd import _sym_mat_to_vector, _vector_to_sym_mat from mne.proj import Projection @@ -477,8 +477,8 @@ def test_csd_multitaper(): _test_csd_matrix(csd) # Test equivalence with PSD - psd, psd_freqs = psd_multitaper(epochs, fmin=1e-3, - normalization='full') # omit DC + spectrum = epochs.compute_psd(fmin=1e-3, normalization='full') # omit DC + psd, psd_freqs = spectrum.get_data(return_freqs=True) csd = csd_multitaper(epochs) assert_allclose(psd_freqs, csd.frequencies) csd = np.array([np.diag(csd.get_data(index=ii)) diff --git a/mne/time_frequency/tests/test_multitaper.py b/mne/time_frequency/tests/test_multitaper.py index 33b499af791..69066a4a01d 100644 --- a/mne/time_frequency/tests/test_multitaper.py +++ b/mne/time_frequency/tests/test_multitaper.py @@ -1,12 +1,11 @@ +# -*- coding: utf-8 -*- import numpy as np import pytest from numpy.testing import assert_array_almost_equal -from mne.time_frequency import psd_multitaper +from mne.time_frequency import psd_array_multitaper from mne.time_frequency.multitaper import dpss_windows from mne.utils import requires_nitime, _record_warnings -from mne.io import RawArray -from mne import create_info @requires_nitime @@ -35,29 +34,33 @@ def test_dpss_windows(): @requires_nitime -def test_multitaper_psd(): +@pytest.mark.parametrize('n_times', (100, 101)) +@pytest.mark.parametrize('adaptive, n_jobs', + [(False, 1), (True, 1), (True, 2)]) +def test_multitaper_psd(n_times, adaptive, n_jobs): """Test multi-taper PSD computation.""" import nitime as ni - for n_times in (100, 101): - n_channels = 5 - data = np.random.RandomState(0).randn(n_channels, n_times) - sfreq = 500 - info = create_info(n_channels, sfreq, 'eeg') - raw = RawArray(data, info) - pytest.raises(ValueError, psd_multitaper, raw, sfreq, - normalization='foo') - norm = 'full' - for adaptive, n_jobs in zip((False, True, True), (1, 1, 2)): - psd, freqs = psd_multitaper(raw, adaptive=adaptive, - n_jobs=n_jobs, - normalization=norm) - with _record_warnings(): # nitime integers - freqs_ni, psd_ni, _ = ni.algorithms.spectral.multi_taper_psd( - data, sfreq, adaptive=adaptive, jackknife=False) - assert_array_almost_equal(psd, psd_ni, decimal=4) - if n_times % 2 == 0: - # nitime's frequency definitions must be incorrect, - # they give the same values for 100 and 101 samples - assert_array_almost_equal(freqs, freqs_ni) - with pytest.raises(ValueError, match='use a value of at least'): - psd_multitaper(raw, bandwidth=4.9) + n_channels = 5 + data = np.random.default_rng(0).random((n_channels, n_times)) + sfreq = 500 + with pytest.raises(ValueError, match="Invalid value for the 'normaliza"): + psd_array_multitaper(data, sfreq, normalization='foo') + # compute with MNE + psd, freqs = psd_array_multitaper( + data, sfreq, adaptive=adaptive, n_jobs=n_jobs, normalization='full') + # compute with nitime + freqs_ni, psd_ni, _ = ni.algorithms.spectral.multi_taper_psd( + data, sfreq, adaptive=adaptive, jackknife=False) + # compare + assert_array_almost_equal(psd, psd_ni, decimal=4) + # assert_array_equal(freqs, freqs_ni) + # ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑ + # this is commented out because nitime's freq calculations differ from ours + # so there's no point checking (theirs are wrong; sometimes they return a + # freq component at exactly sfreq/2 when they shouldn't) + # nitime → np.linspace(0, sfreq / 2, n_times // 2 + 1) + # mne → scipy.fft.rfftfreq(n_times, 1. / sfreq) + + # test with bad bandwidth + with pytest.raises(ValueError, match='use a value of at least'): + psd_array_multitaper(data, sfreq, bandwidth=4.9) diff --git a/mne/time_frequency/tests/test_psd.py b/mne/time_frequency/tests/test_psd.py index 3c9131568f4..f558475faf4 100644 --- a/mne/time_frequency/tests/test_psd.py +++ b/mne/time_frequency/tests/test_psd.py @@ -1,21 +1,14 @@ import numpy as np -import os.path as op -from numpy.testing import assert_array_almost_equal, assert_allclose +from numpy.testing import (assert_array_almost_equal, assert_allclose, + assert_array_equal) from scipy.signal import welch import pytest -from mne import pick_types, Epochs, read_events -from mne.io import RawArray, read_raw_fif from mne.utils import catch_logging -from mne.time_frequency import (psd_welch, psd_array_welch, psd_multitaper, - psd_array_multitaper) +from mne.time_frequency import psd_array_welch, psd_array_multitaper from mne.time_frequency.multitaper import _psd_from_mt from mne.time_frequency.psd import _median_biases -base_dir = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data') -raw_fname = op.join(base_dir, 'test_raw.fif') -event_fname = op.join(base_dir, 'test-eve.fif') - def test_psd_nan(): """Test handling of NaN in psd_array_welch.""" @@ -41,152 +34,82 @@ def test_psd_nan(): assert 'hamming window' in log -def test_psd(): +def _make_psd_data(): + """Make noise data with sinusoids in 2 out of 7 channels.""" + rng = np.random.default_rng(0) + n_chan, n_times, sfreq = 7, 8000, 1000 + data = 0.1 * rng.random((n_chan, n_times)) + times = np.arange(n_times) / sfreq + sinusoid_freqs = [8., 50.] + chs_with_sinusoids = [0, 1] + for ix, freq in zip(chs_with_sinusoids, sinusoid_freqs): + data[ix, :] += 2 * np.sin(np.pi * 2. * freq * times) + return data, sfreq, sinusoid_freqs + + +@pytest.mark.parametrize( + 'psd_func, psd_kwargs', + [(psd_array_welch, dict(n_fft=128, window='hann')), + (psd_array_multitaper, dict(low_bias=True))]) +def test_psd(psd_func, psd_kwargs): """Tests the welch and multitaper PSD.""" - raw = read_raw_fif(raw_fname) - picks_psd = [0, 1] - - # Populate raw with sinusoids - rng = np.random.RandomState(40) - data = 0.1 * rng.randn(len(raw.ch_names), raw.n_times) - freqs_sig = [8., 50.] - for ix, freq in zip(picks_psd, freqs_sig): - data[ix, :] += 2 * np.sin(np.pi * 2. * freq * raw.times) - first_samp = raw._first_samps[0] - raw = RawArray(data, raw.info) - - tmin, tmax = 0, 20 # use a few seconds of data - fmin, fmax = 2, 70 # look at frequencies between 2 and 70Hz - n_fft = 128 - - # -- Raw -- - kws_psd = dict(tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, - picks=picks_psd) # Common to all - kws_welch = dict(n_fft=n_fft) - kws_mt = dict(low_bias=True) - funcs = [(psd_welch, kws_welch), - (psd_multitaper, kws_mt)] - - for func, kws in funcs: - kws = kws.copy() - kws.update(kws_psd) - kws.update(verbose='debug') - if func is psd_welch: - kws.update(window='hann') - with catch_logging() as log: - psds, freqs = func(raw, proj=False, **kws) + data, sfreq, sinusoid_freqs = _make_psd_data() + # prepare kwargs + psd_kwargs.update(dict(fmin=2, fmax=70, verbose='debug')) + # compute PSD and test basic conformity + with catch_logging() as log: + psds, freqs = psd_func(data, sfreq, **psd_kwargs) + if psd_func is psd_array_welch: log = log.getvalue() - if func is psd_welch: - assert f'{n_fft}-point FFT on {n_fft} samples with 0 overl' in log - assert 'hann window' in log - psds_proj, freqs_proj = func(raw, proj=True, **kws) - - assert psds.shape == (len(kws['picks']), len(freqs)) - assert np.sum(freqs < 0) == 0 - assert np.sum(psds < 0) == 0 - - # Is power found where it should be - ixs_max = np.argmax(psds, axis=1) - for ixmax, ifreq in zip(ixs_max, freqs_sig): - # Find nearest frequency to the "true" freq - ixtrue = np.argmin(np.abs(ifreq - freqs)) - assert (np.abs(ixmax - ixtrue) < 2) - - # Make sure the projection doesn't change channels it shouldn't - assert_array_almost_equal(psds, psds_proj) - # Array input shouldn't work - pytest.raises(ValueError, func, raw[:3, :20][0]) - + n_fft = psd_kwargs['n_fft'] + assert f'{n_fft}-point FFT on {n_fft} samples with 0 overl' in log + assert 'hann window' in log + assert psds.shape == (data.shape[0], len(freqs)) + assert np.sum(freqs < 0) == 0 + assert np.sum(psds < 0) == 0 + # Is power found where it should be? + ixs_max = np.argmax(psds, axis=1) + for ixmax, ifreq in zip(ixs_max, sinusoid_freqs): + # Find nearest frequency to the "true" freq + ixtrue = np.argmin(np.abs(ifreq - freqs)) + assert (np.abs(ixmax - ixtrue) < 2) + + +def test_psd_array_welch_nperseg_kwarg(): + """Test n_per_seg and padding in psd_array_welch().""" + data, sfreq, _ = _make_psd_data() + # prepare kwargs + kwargs = dict(fmin=2, fmax=70, n_per_seg=128) # test n_per_seg in psd_welch (and padding) - psds1, freqs1 = psd_welch(raw, proj=False, n_fft=128, n_per_seg=128, - **kws_psd) - psds2, freqs2 = psd_welch(raw, proj=False, n_fft=256, n_per_seg=128, - **kws_psd) - assert (len(freqs1) == np.floor(len(freqs2) / 2.)) - assert (psds1.shape[-1] == np.floor(psds2.shape[-1] / 2.)) - - kws_psd.update(dict(n_fft=tmax * 1.1 * raw.info['sfreq'])) + psds1, freqs1 = psd_array_welch(data, sfreq, n_fft=128, **kwargs) + psds2, freqs2 = psd_array_welch(data, sfreq, n_fft=256, **kwargs) + assert len(freqs1) == np.floor(len(freqs2) / 2.) + assert psds1.shape[-1] == np.floor(psds2.shape[-1] / 2.) + # test bad n_fft with pytest.raises(ValueError, match='n_fft is not allowed to be > n_tim'): - psd_welch(raw, proj=False, n_per_seg=None, - **kws_psd) - kws_psd.update(dict(n_fft=128, n_per_seg=64, n_overlap=90)) + kwargs.update(n_per_seg=None) + bad_n_fft = int(data.shape[-1] * 1.1) + psd_array_welch(data, sfreq, n_fft=bad_n_fft, **kwargs) + # test bad n_overlap with pytest.raises(ValueError, match='n_overlap cannot be greater'): - psd_welch(raw, proj=False, **kws_psd) + kwargs.update(n_per_seg=64) + psd_array_welch(data, sfreq, n_fft=128, n_overlap=90, **kwargs) + # test bad fmin/fmax with pytest.raises(ValueError, match='No frequencies found'): - psd_array_welch(np.zeros((1, 1000)), 1000., fmin=10, fmax=1) + psd_array_welch(data, sfreq, fmin=10, fmax=1) - # -- psd_array_multitaper -- - psd_complex, freq, weights = psd_array_multitaper( - raw._data[:4, :500], raw.info['sfreq'], output='complex') - psd, freq = psd_array_multitaper( - raw._data[:4, :500], raw.info['sfreq'], output='power') + +def test_complex_multitaper(): + """Test complex-valued multitaper output.""" + data, sfreq, _ = _make_psd_data() + psd_complex, freq_complex, weights = psd_array_multitaper( + data[:4, :500], sfreq, output='complex') + psd, freq = psd_array_multitaper(data[:4, :500], sfreq, output='power') + assert_array_equal(freq_complex, freq) assert psd_complex.ndim == 3 # channels x tapers x freqs psd_from_complex = _psd_from_mt(psd_complex, weights) assert_allclose(psd_from_complex, psd) - # -- Epochs/Evoked -- - events = read_events(event_fname) - events[:, 0] -= first_samp - tmin, tmax, event_id = -0.5, 0.5, 1 - epochs = Epochs(raw, events[:10], event_id, tmin, tmax, picks=picks_psd, - proj=False, preload=True, baseline=None) - evoked = epochs.average() - - tmin_full, tmax_full = -1, 1 - epochs_full = Epochs(raw, events[:10], event_id, tmin_full, tmax_full, - picks=picks_psd, proj=False, preload=True, - baseline=None) - kws_psd = dict(tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, - picks=picks_psd) # Common to all - funcs = [(psd_welch, kws_welch), - (psd_multitaper, kws_mt)] - - for func, kws in funcs: - kws = kws.copy() - kws.update(kws_psd) - - psds, freqs = func( - epochs[:1], proj=False, **kws) - psds_proj, freqs_proj = func( - epochs[:1], proj=True, **kws) - psds_f, freqs_f = func( - epochs_full[:1], proj=False, **kws) - - # this one will fail if you add for example 0.1 to tmin - assert_array_almost_equal(psds, psds_f, 27) - # Make sure the projection doesn't change channels it shouldn't - assert_array_almost_equal(psds, psds_proj, 27) - - # Is power found where it should be - ixs_max = np.argmax(psds.mean(0), axis=1) - for ixmax, ifreq in zip(ixs_max, freqs_sig): - # Find nearest frequency to the "true" freq - ixtrue = np.argmin(np.abs(ifreq - freqs)) - assert (np.abs(ixmax - ixtrue) < 2) - assert (psds.shape == (1, len(kws['picks']), len(freqs))) - assert (np.sum(freqs < 0) == 0) - assert (np.sum(psds < 0) == 0) - - # Array input shouldn't work - pytest.raises(ValueError, func, epochs.get_data()) - - # Testing evoked (doesn't work w/ compute_epochs_psd) - psds_ev, freqs_ev = func( - evoked, proj=False, **kws) - psds_ev_proj, freqs_ev_proj = func( - evoked, proj=True, **kws) - - # Is power found where it should be - ixs_max = np.argmax(psds_ev, axis=1) - for ixmax, ifreq in zip(ixs_max, freqs_sig): - # Find nearest frequency to the "true" freq - ixtrue = np.argmin(np.abs(ifreq - freqs_ev)) - assert (np.abs(ixmax - ixtrue) < 2) - - # Make sure the projection doesn't change channels it shouldn't - assert_array_almost_equal(psds_ev, psds_ev_proj, 27) - assert (psds_ev.shape == (len(kws['picks']), len(freqs))) - # Copied from SciPy def _median_bias(n): @@ -194,75 +117,43 @@ def _median_bias(n): return 1 + np.sum(1. / (ii_2 + 1) - 1. / ii_2) -@pytest.mark.parametrize('kind', ('raw', 'epochs', 'evoked')) -def test_psd_welch_average_kwarg(kind): - """Test `average` kwarg of psd_welch().""" - raw = read_raw_fif(raw_fname) - picks_psd = [0, 1] - - # Populate raw with sinusoids - rng = np.random.RandomState(40) - data = 0.1 * rng.randn(len(raw.ch_names), raw.n_times) - freqs_sig = [8., 50.] - for ix, freq in zip(picks_psd, freqs_sig): - data[ix, :] += 2 * np.sin(np.pi * 2. * freq * raw.times) - first_samp = raw._first_samps[0] - raw = RawArray(data, raw.info) - - tmin, tmax = -0.5, 0.5 - fmin, fmax = 0, np.inf - # make these small so that sometimes we get an odd number, sometimes an - # even number of estimates - n_fft = 64 +@pytest.mark.parametrize('crop', (False, True)) +def test_psd_welch_average_kwarg(crop): + """Test `average` kwarg of psd_array_welch().""" + data, sfreq, _ = _make_psd_data() + # prepare kwargs n_per_seg = 32 - n_overlap = 0 - - event_id = 2 - events = read_events(event_fname) - events[:, 0] -= first_samp - - kws = dict(fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, n_fft=n_fft, - n_per_seg=n_per_seg, n_overlap=n_overlap, picks=picks_psd) - - if kind == 'raw': - inst = raw - elif kind == 'epochs': - inst = Epochs(raw, events[:10], event_id, tmin, tmax, picks=picks_psd, - proj=False, preload=True, baseline=None) - elif kind == 'evoked': - inst = Epochs(raw, events[:10], event_id, tmin, tmax, picks=picks_psd, - proj=False, preload=True, baseline=None).average() - else: - raise ValueError('Unknown parametrization passed to test, check test ' - 'for typos.') - - psds_mean, freqs_mean = psd_welch(inst=inst, average='mean', **kws) - psds_median, freqs_median = psd_welch(inst=inst, average='median', **kws) - psds_unagg, freqs_unagg = psd_welch(inst=inst, average=None, **kws) - + kwargs = dict(fmin=0, fmax=np.inf, n_fft=64, n_per_seg=n_per_seg, + n_overlap=0) + # optionally crop data by n_per_seg so that we are sure to test both an + # odd number and an even number of estimates (for median bias) + if crop: + data = data[..., :-n_per_seg] + # run with average=mean/median/None + psds_mean, freqs_mean = psd_array_welch( + data, sfreq, average='mean', **kwargs) + psds_median, freqs_median = psd_array_welch( + data, sfreq, average='median', **kwargs) + psds_unagg, freqs_unagg = psd_array_welch( + data, sfreq, average=None, **kwargs) # Frequencies should be equal across all "average" types, as we feed in # the exact same data. - assert_allclose(freqs_mean, freqs_median) - assert_allclose(freqs_mean, freqs_unagg) - + assert_array_equal(freqs_mean, freqs_median) + assert_array_equal(freqs_mean, freqs_unagg) # For `average=None`, the last dimension contains the un-aggregated # segments. assert psds_mean.shape == psds_median.shape assert psds_mean.shape == psds_unagg.shape[:-1] - assert_allclose(psds_mean, psds_unagg.mean(axis=-1)) - + assert_array_equal(psds_mean, psds_unagg.mean(axis=-1)) # Compare with manual median calculation (_median_bias copied from SciPy) - bias = _median_bias(psds_unagg.shape[-1]) assert_allclose(psds_median, np.median(psds_unagg, axis=-1) / bias) - if kind == 'epochs': - want_shape = (3, 2, 33, 18) - elif kind == 'evoked': - want_shape = (2, 33, 18) - else: - assert kind == 'raw' - want_shape = (2, 33, 9) - assert psds_unagg.shape == want_shape + # check shape of unagg + n_chan, n_times = data.shape + n_freq = len(freqs_unagg) + n_segs = np.ceil(n_times / n_per_seg).astype(int) + assert n_segs % 2 == (1 if crop else 0) + assert psds_unagg.shape == (n_chan, n_freq, n_segs) @pytest.mark.parametrize('n', (2, 3, 5, 8, 12, 13, 14, 15)) @@ -279,43 +170,25 @@ def test_median_biases(n): @pytest.mark.slowtest def test_compares_psd(): """Test PSD estimation on raw for plt.psd and scipy.signal.welch.""" - raw = read_raw_fif(raw_fname) - - exclude = raw.info['bads'] + ['MEG 2443', 'EEG 053'] # bads + 2 more - - # picks MEG gradiometers - picks = pick_types(raw.info, meg='grad', eeg=False, stim=False, - exclude=exclude)[:2] - - tmin, tmax = 0, 10 # use the first 60s of data - fmin, fmax = 2, 70 # look at frequencies between 5 and 70Hz + data, sfreq, _ = _make_psd_data() n_fft = 2048 - - # Compute psds with the new implementation using Welch - psds_welch, freqs_welch = psd_welch(raw, tmin=tmin, tmax=tmax, fmin=fmin, - fmax=fmax, proj=False, picks=picks, - n_fft=n_fft, n_jobs=None) - - # Compute psds with plt.psd - start, stop = raw.time_as_index([tmin, tmax]) - data, times = raw[picks, start:(stop + 1)] - out = [welch(d, fs=raw.info['sfreq'], nperseg=n_fft, noverlap=0) - for d in data] - freqs_mpl = out[0][0] - psds_mpl = np.array([o[1] for o in out]) - - mask = (freqs_mpl >= fmin) & (freqs_mpl <= fmax) - freqs_mpl = freqs_mpl[mask] - psds_mpl = psds_mpl[:, mask] - - assert_array_almost_equal(psds_welch, psds_mpl) - assert_array_almost_equal(freqs_welch, freqs_mpl) - - assert (psds_welch.shape == (len(picks), len(freqs_welch))) - assert (psds_mpl.shape == (len(picks), len(freqs_mpl))) - - assert (np.sum(freqs_welch < 0) == 0) - assert (np.sum(freqs_mpl < 0) == 0) - - assert (np.sum(psds_welch < 0) == 0) - assert (np.sum(psds_mpl < 0) == 0) + fmin, fmax = 2, 70 + # Compute PSD with psd_array_welch + psds_mne, freqs_mne = psd_array_welch( + data, sfreq, fmin=fmin, fmax=fmax, n_fft=n_fft) + # Compute psds with scipy.signal.welch + freqs_scipy, psds_scipy = welch( + data, fs=sfreq, nperseg=n_fft, noverlap=0, window='hamming') + # restrict to the relevant frequencies + mask = (freqs_scipy >= fmin) & (freqs_scipy <= fmax) + freqs_scipy = freqs_scipy[mask] + psds_scipy = psds_scipy[:, mask] + # make sure they match + assert_array_almost_equal(psds_mne, psds_scipy) + assert_array_equal(freqs_mne, freqs_scipy) + assert (psds_mne.shape == (data.shape[0], len(freqs_mne))) + assert (psds_scipy.shape == (data.shape[0], len(freqs_scipy))) + assert (np.sum(freqs_mne < 0) == 0) + assert (np.sum(freqs_scipy < 0) == 0) + assert (np.sum(psds_mne < 0) == 0) + assert (np.sum(psds_scipy < 0) == 0) diff --git a/mne/time_frequency/tests/test_spectrum.py b/mne/time_frequency/tests/test_spectrum.py new file mode 100644 index 00000000000..a132d57ccb6 --- /dev/null +++ b/mne/time_frequency/tests/test_spectrum.py @@ -0,0 +1,184 @@ +from functools import partial + +import numpy as np +import pytest +from numpy.testing import assert_array_equal + +from mne.time_frequency import read_spectrum +from mne.time_frequency.multitaper import _psd_from_mt +from mne.utils import requires_h5py, requires_pandas + + +def test_spectrum_errors(raw): + """Test for expected errors in the .compute_psd() method.""" + with pytest.raises(ValueError, match='must not exceed ½ the sampling'): + raw.compute_psd(fmax=raw.info['sfreq'] * 0.51) + with pytest.raises(TypeError, match='unexpected keyword argument foo for'): + raw.compute_psd(foo=None) + with pytest.raises(TypeError, match='keyword arguments foo, bar for'): + raw.compute_psd(foo=None, bar=None) + + +@pytest.mark.parametrize('method', ('welch', 'multitaper')) +@pytest.mark.parametrize( + ('fmin, fmax, tmin, tmax, picks, proj, n_fft, n_overlap, n_per_seg, ' + 'average, window, bandwidth, adaptive, low_bias, normalization'), + [[0, np.inf, None, None, None, False, 256, 0, None, + 'mean', 'hamming', None, False, True, 'length'], # defaults + [5, 50, 1, 6, 'grad', True, 128, 8, 32, + 'median', 'triang', 10, True, False, 'full'] # non-defaults + ] +) +def test_spectrum_params(method, fmin, fmax, tmin, tmax, picks, proj, n_fft, + n_overlap, n_per_seg, average, window, bandwidth, + adaptive, low_bias, normalization, raw): + """Test valid parameter combinations in the .compute_psd() method.""" + kwargs = dict(method=method, fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, + picks=picks, proj=proj) + if method == 'welch': + kwargs.update(n_fft=n_fft, n_overlap=n_overlap, n_per_seg=n_per_seg, + average=average, window=window) + else: + kwargs.update(bandwidth=bandwidth, adaptive=adaptive, + low_bias=low_bias, normalization=normalization) + raw.compute_psd(**kwargs) + + +@requires_h5py +@pytest.mark.parametrize('inst', ('raw', 'epochs', 'evoked')) +def test_spectrum_io(inst, tmp_path, request, evoked): + """Test save/load of spectrum objects.""" + fname = tmp_path / f'{inst}-spectrum.h5' + # ↓ XXX workaround: + # ↓ parametrized fixtures are not accessible via request.getfixturevalue + # ↓ https://github.com/pytest-dev/pytest/issues/4666#issuecomment-456593913 + inst = evoked if inst == 'evoked' else request.getfixturevalue(inst) + orig = inst.compute_psd() + orig.save(fname) + loaded = read_spectrum(fname) + assert orig == loaded + + +def test_spectrum_copy(raw): + """Test copying Spectrum objects.""" + spect = raw.compute_psd() + spect_copy = spect.copy() + assert spect == spect_copy + assert id(spect) != id(spect_copy) + spect_copy._freqs = None + assert spect.freqs is not None + + +def test_spectrum_getitem_raw(raw): + """Test Spectrum.__getitem__ for Raw-derived spectra.""" + spect = raw.compute_psd() + want = spect.get_data(slice(1, 3), fmax=7) + freq_idx = np.searchsorted(spect.freqs, 7) + got = spect[1:3, :freq_idx] + assert_array_equal(want, got) + + +def test_spectrum_getitem_epochs(epochs): + """Test Spectrum.__getitem__ for Epochs-derived spectra.""" + spect = epochs.compute_psd() + # testing data has just one epoch, its event_id label is "1" + want = spect.get_data() + got = spect['1'].get_data() + assert_array_equal(want, got) + + +def _agg_helper(df, weights, group_cols): + """Aggregate complex multitaper spectrum after conversion to DataFrame.""" + from pandas import Series + + unagged_columns = df[group_cols].iloc[0].values.tolist() + x_mt = df.drop(columns=group_cols).values[np.newaxis].T + psd = _psd_from_mt(x_mt, weights) + psd = np.atleast_1d(np.squeeze(psd)).tolist() + _df = dict(zip(df.columns, unagged_columns + psd)) + return Series(_df) + + +@requires_pandas +@pytest.mark.parametrize('long_format', (False, True)) +@pytest.mark.parametrize('method', ('welch', 'multitaper')) +def test_unaggregated_spectrum_to_data_frame(raw, long_format, method): + """Test converting complex multitaper spectra to data frame.""" + from pandas.testing import assert_frame_equal + + # aggregated spectrum → dataframe + orig_df = (raw.compute_psd(method=method) + .to_data_frame(long_format=long_format)) + # unaggregated welch or complex multitaper → + # aggregate w/ pandas (to make sure we did reshaping right) + kwargs = {'average': False} if method == 'welch' else {'output': 'complex'} + spectrum = raw.compute_psd(method=method, **kwargs) + df = spectrum.to_data_frame(long_format=long_format) + group_by = ['freq'] + drop_cols = ['segment'] if method == 'welch' else ['taper'] + if long_format: + group_by.append('channel') + drop_cols.append('ch_type') + orig_df.drop(columns='ch_type', inplace=True) + # only do a couple freq bins, otherwise test takes forever for multitaper + subset = partial(np.isin, test_elements=spectrum.freqs[:2]) + df = df.loc[subset(df['freq'])] + orig_df = orig_df.loc[subset(orig_df['freq'])] + # aggregate + gb = (df.drop(columns=drop_cols) + .groupby(group_by, sort=False, as_index=False)) + if method == 'welch': + agg_df = gb.aggregate(np.nanmean) + else: + agg_df = gb.apply(_agg_helper, spectrum._mt_weights, group_by) + # even with check_categorical=False, we know that the *data* matches; + # what may differ is the order of the "levels" in the *metadata* for the + # channel name column + assert_frame_equal(agg_df, orig_df, check_categorical=False) + + +@requires_pandas +@pytest.mark.parametrize('inst', ('raw', 'epochs', 'evoked')) +def test_spectrum_to_data_frame(inst, request, evoked): + """Test the to_data_frame method for Spectrum.""" + from pandas.testing import assert_frame_equal + + # setup + is_epochs = inst == 'epochs' + # ↓ XXX workaround: + # ↓ parametrized fixtures are not accessible via request.getfixturevalue + # ↓ https://github.com/pytest-dev/pytest/issues/4666#issuecomment-456593913 + inst = evoked if inst == 'evoked' else request.getfixturevalue(inst) + extra_dim = () if is_epochs else (1,) + extra_cols = ['freq', 'condition', 'epoch'] if is_epochs else ['freq'] + # compute PSD + spectrum = inst.compute_psd() + n_epo, n_chan, n_freq = extra_dim + spectrum.get_data().shape + # test wide format + df_wide = spectrum.to_data_frame() + n_row, n_col = df_wide.shape + assert n_row == n_freq + assert n_col == n_chan + len(extra_cols) + assert set(spectrum.ch_names + extra_cols) == set(df_wide.columns) + # test long format + df_long = spectrum.to_data_frame(long_format=True) + n_row, n_col = df_long.shape + assert n_row == n_epo * n_freq * n_chan + base_cols = ['channel', 'ch_type', 'value'] + assert n_col == len(base_cols + extra_cols) + assert set(base_cols + extra_cols) == set(df_long.columns) + # test index + index = extra_cols[-2:] # ['freq'] or ['condition', 'epoch'] + df = spectrum.to_data_frame(index=index) + if is_epochs: + index_tuple = (list(spectrum.event_id)[0], # condition + spectrum.selection[0]) # epoch number + subset = df.loc[index_tuple] + assert subset.shape == (n_freq, n_chan + 1) # + 1 is the freq column + with pytest.raises(ValueError, match='"time" is not a valid option'): + spectrum.to_data_frame(index='time') + # test picks + picks = [0, 1] + _pick_first = spectrum.pick(picks).to_data_frame() + _pick_last = spectrum.to_data_frame(picks=picks) + assert_frame_equal(_pick_first, _pick_last) diff --git a/mne/utils/check.py b/mne/utils/check.py index aa6e3ce0c61..59b01485f1a 100644 --- a/mne/utils/check.py +++ b/mne/utils/check.py @@ -272,8 +272,9 @@ def _check_preload(inst, msg): from ..epochs import BaseEpochs from ..evoked import Evoked from ..time_frequency import _BaseTFR + from ..time_frequency.spectrum import BaseSpectrum - if isinstance(inst, (_BaseTFR, Evoked)): + if isinstance(inst, (_BaseTFR, Evoked, BaseSpectrum)): pass else: name = "epochs" if isinstance(inst, BaseEpochs) else 'raw' diff --git a/mne/utils/dataframe.py b/mne/utils/dataframe.py index 16956bb2b86..400058e0e83 100644 --- a/mne/utils/dataframe.py +++ b/mne/utils/dataframe.py @@ -97,6 +97,6 @@ def _build_data_frame(inst, data, picks, long_format, mindex, index, df = _inplace(df, 'set_index', keys=index) # convert channel/vertex/ch_type columns to factors to_factor = [c for c in df.columns.tolist() - if c not in ('time', 'value')] + if c not in ('freq', 'time', 'value')] _set_pandas_dtype(df, to_factor, 'category') return df diff --git a/mne/utils/docs.py b/mne/utils/docs.py index bfb17cf27e1..61aefeed01e 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -4,7 +4,6 @@ # # License: BSD-3-Clause -from copy import deepcopy import inspect import os import os.path as op @@ -12,11 +11,12 @@ import sys import warnings import webbrowser +from copy import deepcopy from decorator import FunctionMaker -from ._bunch import BunchConst from ..defaults import HEAD_SIZE_DEFAULT +from ._bunch import BunchConst def _reflow_param_docstring(docstring, has_first_line=True, width=75): @@ -200,12 +200,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict['applyfun_summary_raw'] = \ applyfun_summary.format('raw', applyfun_preload) -docdict['area_alpha_plot_psd'] = """ +docdict['area_alpha_plot_psd'] = """\ area_alpha : float Alpha for the area. """ -docdict['area_mode_plot_psd'] = """ +docdict['area_mode_plot_psd'] = """\ area_mode : str | None Mode for plotting area. If 'std', the mean +/- 1 STD (across channels) will be plotted. If 'range', the min and max (across channels) will be @@ -220,7 +220,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Freesurfer subject directory. """ -docdict['average_plot_psd'] = """ +docdict['average_plot_psd'] = """\ average : bool If False, the PSDs of all channels is displayed. No averaging is done and parameters area_mode and area_alpha are ignored. When @@ -228,7 +228,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): drag) to plot a topomap. """ -docdict['average_psd'] = """ +docdict['average_psd'] = """\ average : str | None How to average the segments. If ``mean`` (default), calculate the arithmetic mean. If ``median``, calculate the median, corrected for @@ -267,6 +267,18 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): 'match the number of ``times`` provided (unless ``times`` is ``None``)') docdict['axes_plot_topomap'] = _axes.format( 'axes', 'match the length of ``bands``') +docdict['axes_spectrum_plot'] = _axes.format( + 'axes', _ch_types_present.format(':class:`~mne.time_frequency.Spectrum`')) +docdict['axes_spectrum_plot_topo'] = _axes.format( + 'axes', + 'be length 1 (for efficiency, subplots for each channel are simulated ' + 'within a single :class:`~matplotlib.axes.Axes` object)') + +docdict['axis_facecolor'] = """\ +axis_facecolor : str | tuple + A matplotlib-compatible color to use for the axis background. + Defaults to black. +""" docdict['azimuth'] = """ azimuth : float @@ -369,6 +381,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): """ +docdict['block'] = """\ +block : bool + Whether to halt program execution until the figure is closed. + May not work on all systems / platforms. Defaults to ``False``. +""" + docdict['border_topomap'] = """ border : float | 'mean' Value to extrapolate to on the topomap borders. If ``'mean'`` (default), @@ -472,13 +490,16 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): ch_names=[[], ['MEG0111', 'MEG2563'], ['MEG1443']]) """ -docdict['ch_type_evoked_topomap'] = """ +_ch_type_topomap = """\ ch_type : 'mag' | 'grad' | 'planar1' | 'planar2' | 'eeg' | None - The channel type to plot. For 'grad', the gradiometers are collected in - pairs and the RMS for each pair is plotted. - If None, then channels are chosen in the order given above. + The channel type to plot. For ``'grad'``, the gradiometers are + collected in pairs and the {} for each pair is plotted. If + ``None`` the first available channel type from order shown above is + used. Defaults to ``None``. """ +docdict['ch_type_psd_topomap'] = _ch_type_topomap.format('mean') + docdict['ch_type_set_eeg_reference'] = """ ch_type : list of str | str The name of the channel type to apply the reference to. @@ -489,13 +510,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. versionadded:: 0.19 """ -docdict['ch_type_topomap'] = """ -ch_type : str - The channel type being plotted. Determines the ``'auto'`` - extrapolation mode. - - .. versionadded:: 0.21 -""" +docdict['ch_type_topomap'] = _ch_type_topomap.format('RMS') chwise = """ channel_wise : bool @@ -608,12 +623,18 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): A list of anything matplotlib accepts: string, RGB, hex, etc. """ -docdict['color_plot_psd'] = """ +docdict['color_plot_psd'] = """\ color : str | tuple A matplotlib-compatible color to use. Has no effect when spatial_colors=True. """ +docdict['color_spectrum_plot_topo'] = """\ +color : str | tuple + A matplotlib-compatible color to use for the curves. Defaults to + white. +""" + docdict['colorbar_topomap'] = """ colorbar : bool Plot a colorbar in the rightmost column of the figure. @@ -742,6 +763,9 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict['dB_plot_topomap'] = _dB.format( ' following the application of ``agg_fun``', ' Ignored if ``normalize=True``.') +docdict['dB_spectrum_plot'] = _dB.format('', '') +docdict['dB_spectrum_plot_topo'] = _dB.format( + '', ' Ignored if ``normalize=True``.') docdict['daysback_anonymize_info'] = """ daysback : int | None @@ -755,7 +779,6 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): dbs : bool If True (default), show DBS (deep brain stimulation) electrodes. """ - docdict['decim'] = """ decim : int Factor by which to subsample the data. @@ -953,7 +976,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): time are included. Defaults to ``-0.2`` and ``0.5``, respectively. """ -docdict['estimate_plot_psd'] = """ +docdict['estimate_plot_psd'] = """\ estimate : str, {'auto', 'power', 'amplitude'} Can be "power" for power spectral density (PSD), "amplitude" for amplitude spectrum density (ASD), or "auto" (default), which uses @@ -1030,6 +1053,17 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): (below the nasion) and positive Y values (in front of the LPA/RPA). """ +_exclude_spectrum = """\ +exclude : list of str | 'bads' + Channel names to exclude{}. If ``'bads'``, channels + in ``spectrum.info['bads']`` are excluded; pass an empty list to + plot all channels (including "bad" channels, if any). +""" + +docdict['exclude_spectrum_get_data'] = _exclude_spectrum.format('') +docdict['exclude_spectrum_plot'] = _exclude_spectrum.format( + ' from being drawn') + docdict['export_edf_note'] = """ For EDF exports, only channels measured in Volts are allowed; in MNE-Python this means channel types 'eeg', 'ecog', 'seeg', 'emg', 'eog', 'ecg', 'dbs', @@ -1188,6 +1222,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): and if absent, falls back to ``'estimated'``. """ +docdict['fig_facecolor'] = """\ +fig_facecolor : str | tuple + A matplotlib-compatible color to use for the figure background. + Defaults to black. +""" + docdict['filter_length'] = """ filter_length : str | int Length of the FIR filter to use (if applicable): @@ -1371,6 +1411,18 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): (False, default). """ +_getitem_base = """\ +data : ndarray + The selected spectral data. Shape will be + ``({}n_channels, n_freqs)`` for normal power spectra, + ``({}n_channels, n_freqs, n_segments)`` for unaggregated + Welch estimates, or ``({}n_channels, n_tapers, n_freqs)`` + for unaggregated multitaper estimates. +""" +_fill_epochs = ['n_epochs, '] * 3 +docdict['getitem_epochspectrum_return'] = _getitem_base.format(*_fill_epochs) +docdict['getitem_spectrum_return'] = _getitem_base.format('', '', '') + docdict['group_by_browse'] = """ group_by : str How to group channels. ``'type'`` groups by channel type, @@ -1545,8 +1597,6 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): (depending on the value of ``time_format``). {} """ -docdict['index_df'] = _index_df_base - datetime = ':class:`pandas.DatetimeIndex`, ' multiindex = ('If a list of two or more string values, a ' ':class:`pandas.MultiIndex` will be created. ') @@ -1699,7 +1749,14 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Support for volume source estimates. """ -docdict['line_alpha_plot_psd'] = """ +docdict['layout_spectrum_plot_topo'] = """\ +layout : instance of Layout | None + Layout instance specifying sensor positions (does not need to be + specified for Neuromag data). If ``None`` (default), the layout is + inferred from the data. +""" + +docdict['line_alpha_plot_psd'] = """\ line_alpha : float | None Alpha for the PSD line. Can be None (default) to use 1.0 when ``average=True`` and 0.1 when ``average=False``. @@ -1708,18 +1765,20 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): _long_format_df_base = """ long_format : bool If True, the DataFrame is returned in long format where each row is one - observation of the signal at a unique combination of time point{}. + observation of the signal at a unique combination of {}. {}Defaults to ``False``. """ ch_type = ('For convenience, a ``ch_type`` column is added to facilitate ' 'subsetting the resulting DataFrame. ') -raw = (' and channel', ch_type) -epo = (', channel, epoch number, and condition', ch_type) -stc = (' and vertex', '') +raw = ('time point and channel', ch_type) +epo = ('time point, channel, epoch number, and condition', ch_type) +stc = ('time point and vertex', '') +spe = ('frequency and channel', ch_type) docdict['long_format_df_epo'] = _long_format_df_base.format(*epo) docdict['long_format_df_raw'] = _long_format_df_base.format(*raw) +docdict['long_format_df_spe'] = _long_format_df_base.format(*spe) docdict['long_format_df_stc'] = _long_format_df_base.format(*stc) docdict['loose'] = """ @@ -1865,6 +1924,30 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): forward-backward filtering (via filtfilt). """ +docdict['method_kw_psd'] = """\ +**method_kw + Additional keyword arguments passed to the spectral estimation + function (e.g., ``n_fft, n_overlap, n_per_seg, average, window`` + for Welch method, or + ``bandwidth, adaptive, low_bias, normalization`` for multitaper + method). See :func:`~mne.time_frequency.psd_array_welch` and + :func:`~mne.time_frequency.psd_array_multitaper` for details. +""" + +_method_psd = """\ +method : 'welch' | 'multitaper'{} + Spectral estimation method. ``'welch'`` uses Welch's method + :footcite:`Welch1967`, ``'multitaper'`` uses DPSS tapers + :footcite:`Slepian1978`.{} +""" +docdict['method_plot_psd_auto'] = _method_psd.format( + " | 'auto'", + (" ``'auto'`` (default) uses Welch's method for continuous data and " + "multitaper for :class:`~mne.Epochs` or :class:`~mne.Evoked` data.") +) +docdict['method_psd'] = _method_psd.format('', '') +docdict['method_psd_auto'] = _method_psd.format(" | 'auto'", '') + docdict['mode_eltc'] = """ mode : str Extraction mode, see Notes. @@ -1923,7 +2006,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Default n_comp=1. """ -docdict['n_jobs'] = """ +docdict['n_jobs'] = """\ n_jobs : int | None The number of jobs to run in parallel. If ``-1``, it is set to the number of CPU cores. Requires the :mod:`joblib` package. @@ -2038,6 +2121,15 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): of ``mne-qt-browser``. """ +_notes_plot_psd = """\ +This {} exists to support legacy code; for new code the preferred +idiom is ``inst.compute_psd().plot()`` (where ``inst`` is an instance +of :class:`~mne.io.Raw`, :class:`~mne.Epochs`, or :class:`~mne.Evoked`). +""" + +docdict['notes_plot_*_psd_func'] = _notes_plot_psd.format('function') +docdict['notes_plot_psd_meth'] = _notes_plot_psd.format('method') + docdict['notes_tmax_included_by_default'] = """ Unlike Python slices, MNE time intervals by default include **both** their end points; ``crop(tmin, tmax)`` returns the interval @@ -2460,14 +2552,17 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): the SDR step. """ -docdict['plot_psd_doc'] = """ -Plot the power spectral density across channels. +docdict["plot_psd_doc"] = """\ +Plot power or amplitude spectra. -Different channel types are drawn in sub-plots. When the data have been +Separate plots are drawn for each channel type. When the data have been processed with a bandpass, lowpass or highpass filter, dashed lines (╎) -indicate the boundaries of the filter. The line noise frequency is -also indicated with a dashed line (⋮) +indicate the boundaries of the filter. The line noise frequency is also +indicated with a dashed line (⋮). If ``average=False``, the plot will +be interactive, and click-dragging on the spectrum will generate a +scalp topography plot for the chosen frequency range in a new figure """ +# lack of trailing . is intentional; it must be in actual docstring ↑↑↑ (D400) docdict['precompute'] = """ precompute : bool | str @@ -2529,6 +2624,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Support for 'reconstruct' was added. """ +docdict['proj_psd'] = """\ +proj : bool + Whether to apply SSP projection vectors before spectral estimation. + Default is ``False``. +""" + docdict['proj_topomap_kwargs'] = """ cmap : matplotlib colormap | (colormap, bool) | 'interactive' | None Colormap to use. If tuple, the first value indicates the colormap to @@ -2964,7 +3065,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. footbibliography:: """ -docdict['show'] = """ +docdict['show'] = """\ show : bool Show the figure if ``True``. """ @@ -3030,9 +3131,10 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): The smoothing factor to be applied. Default 0 is no smoothing. """ -docdict['spatial_colors_plot_psd'] = """ +docdict['spatial_colors_psd'] = """\ spatial_colors : bool - Whether to use spatial colors. Only used when ``average=False``. + Whether to color spectrum lines by channel location. Ignored if + ``average=True``. """ _sphere_header = ( @@ -3418,6 +3520,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Start time of the raw data to use in seconds (must be >= 0). """ +docdict['tmin_tmax_psd'] = """\ +tmin, tmax : float | None + First and last times to include, in seconds. ``None`` uses the first or + last time present in the data. Default is ``tmin=None, tmax=None`` (all + times). +""" + docdict['tol_kind_rank'] = """ tol_kind : str Can be: "absolute" (default) or "relative". Only used if ``tol`` is a @@ -3567,7 +3676,8 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Control verbosity of the logging output. If ``None``, use the default verbosity level. See the :ref:`logging documentation ` and :func:`mne.verbose` for details. Should only be passed as a keyword - argument.""" + argument. +""" docdict['vertices_volume'] = """ vertices : list of array of int @@ -3657,7 +3767,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): solution. """ -docdict['window_psd'] = """ +docdict['window_psd'] = """\ window : str | float | tuple Windowing function to use. See :func:`scipy.signal.get_window`. """ @@ -3671,9 +3781,9 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # %% # X -docdict['xscale_plot_psd'] = """ -xscale : str - Can be 'linear' (default) or 'log'. +docdict['xscale_plot_psd'] = """\ +xscale : 'linear' | 'log' + Scale of the frequency axis. Default is ``'linear'``. """ # %% diff --git a/mne/utils/mixin.py b/mne/utils/mixin.py index 8333036605d..edc245c5636 100644 --- a/mne/utils/mixin.py +++ b/mne/utils/mixin.py @@ -187,7 +187,7 @@ def _getitem(self, item, reason='IGNORED', copy=True, drop_event_id=True, subset of epochs (and optionally array with kept epoch indices) """ data = self._data - del self._data + self._data = None inst = self.copy() if copy else self self._data = inst._data = data del self diff --git a/mne/utils/numerics.py b/mne/utils/numerics.py index 40460ccf592..3661ba4d006 100644 --- a/mne/utils/numerics.py +++ b/mne/utils/numerics.py @@ -6,6 +6,7 @@ # License: BSD-3-Clause import hashlib +import inspect import numbers import operator import os @@ -750,8 +751,8 @@ def object_diff(a, b, pre='', *, allclose=False): Parameters ---------- a : object - Currently supported: dict, list, tuple, ndarray, int, str, bytes, - float, StringIO, BytesIO. + Currently supported: class, dict, list, tuple, ndarray, + int, str, bytes, float, StringIO, BytesIO. b : object Must be same type as ``a``. pre : str @@ -772,8 +773,11 @@ def object_diff(a, b, pre='', *, allclose=False): if isinstance(a, sub) and isinstance(b, sub): break else: - return pre + ' type mismatch (%s, %s)\n' % (type(a), type(b)) - if isinstance(a, dict): + return (f'{pre} type mismatch ({type(a)}, {type(b)})\n') + if inspect.isclass(a): + if inspect.isclass(b) and a != b: + return f'{pre} class mismatch ({a}, {b})\n' + elif isinstance(a, dict): k1s = _sort_keys(a) k2s = _sort_keys(b) m1 = set(k2s) - set(k1s) diff --git a/mne/viz/epochs.py b/mne/viz/epochs.py index a4582c52a47..754859846a0 100644 --- a/mne/viz/epochs.py +++ b/mne/viz/epochs.py @@ -936,16 +936,9 @@ def plot_epochs_psd(epochs, fmin=0, fmax=np.inf, tmin=None, tmax=None, ---------- epochs : instance of Epochs The epochs object. - fmin : float - Start frequency to consider. - fmax : float - End frequency to consider. - tmin : float | None - Start time to consider. - tmax : float | None - End time to consider. - proj : bool - Apply projection. + %(fmin_fmax_psd)s + %(tmin_tmax_psd)s + %(proj_psd)s bandwidth : float The bandwidth of the multi taper windowing function in Hz. The default value is a window half-bandwidth of 4. @@ -968,7 +961,7 @@ def plot_epochs_psd(epochs, fmin=0, fmax=np.inf, tmin=None, tmax=None, %(n_jobs)s %(average_plot_psd)s %(line_alpha_plot_psd)s - %(spatial_colors_plot_psd)s + %(spatial_colors_psd)s %(sphere_topomap_auto)s exclude : list of str | 'bads' Channels names to exclude from being shown. If 'bads', the bad channels @@ -982,19 +975,19 @@ def plot_epochs_psd(epochs, fmin=0, fmax=np.inf, tmin=None, tmax=None, ------- fig : instance of Figure Figure with frequency spectra of the data channels. + + Notes + ----- + %(notes_plot_*_psd_func)s """ - from ._mpl_figure import _psd_figure - - # generate figure - # epochs always use multitaper, not Welch, so no need to allow "window" - # param above - fig = _psd_figure( - inst=epochs, proj=proj, picks=picks, axes=ax, tmin=tmin, tmax=tmax, - fmin=fmin, fmax=fmax, sphere=sphere, xscale=xscale, dB=dB, - average=average, estimate=estimate, area_mode=area_mode, - line_alpha=line_alpha, area_alpha=area_alpha, color=color, - spatial_colors=spatial_colors, n_jobs=n_jobs, bandwidth=bandwidth, - adaptive=adaptive, low_bias=low_bias, normalization=normalization, - window='hamming', exclude=exclude) - plt_show(show) + fig = epochs.plot_psd( + fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, picks=picks, + proj=proj, method='multitaper', + ax=ax, color=color, xscale=xscale, area_mode=area_mode, + area_alpha=area_alpha, dB=dB, estimate=estimate, show=show, + line_alpha=line_alpha, spatial_colors=spatial_colors, sphere=sphere, + exclude=exclude, n_jobs=n_jobs, average=average, verbose=verbose, + # these are **method_kw: + window='hamming', bandwidth=bandwidth, adaptive=adaptive, + low_bias=low_bias, normalization=normalization) return fig diff --git a/mne/viz/ica.py b/mne/viz/ica.py index e782df82ba4..7aee10cdfe7 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -23,7 +23,6 @@ from ..defaults import _handle_default from ..io.meas_info import create_info from ..io.pick import pick_types, _picks_to_idx -from ..time_frequency.psd import psd_multitaper from ..utils import _reject_data_segments, verbose @@ -465,7 +464,13 @@ def _fast_plot_ica_properties(ica, inst, picks=None, axes=None, dB=True, if 'fmax' not in psd_args: psd_args['fmax'] = min(lp * 1.25, Nyquist) plot_lowpass_edge = lp < Nyquist and (psd_args['fmax'] > lp) - psds, freqs = psd_multitaper(epochs_src, picks=picks, **psd_args) + spectrum = epochs_src.compute_psd(picks=picks, **psd_args) + # we've already restricted picks ↑↑↑↑↑↑↑↑↑↑↑ + # in the spectrum object, so here we do picks=all ↓↓↓↓↓↓↓↓↓↓↓ + psds, freqs = spectrum.get_data(return_freqs=True, picks='all', exclude=[]) + # we also pass exclude=[] so that when this is called by right-clicking in + # a plot_sources() window on an ICA component name that has been marked as + # bad, we can still get a plot of it. def set_title_and_labels(ax, title, xlab, ylab): if title: diff --git a/mne/viz/raw.py b/mne/viz/raw.py index 8839654912f..4b7a79c4904 100644 --- a/mne/viz/raw.py +++ b/mne/viz/raw.py @@ -6,19 +6,16 @@ # # License: Simplified BSD -from functools import partial from collections import OrderedDict import numpy as np from ..annotations import _annotations_starts_stops from ..filter import create_filter -from ..io.pick import pick_types, _pick_data_channels, pick_info, pick_channels +from ..io.pick import pick_types, pick_channels from ..utils import verbose, _validate_type, _check_option -from ..time_frequency import psd_welch from ..defaults import _handle_default -from .topo import _plot_topo, _plot_timeseries, _plot_timeseries_unified -from .utils import (plt_show, _compute_scalings, _handle_decim, _check_cov, +from .utils import (_compute_scalings, _handle_decim, _check_cov, _shorten_path_from_middle, _handle_precompute, _get_channel_plotting_order, _make_event_color_dict) @@ -374,20 +371,12 @@ def plot_raw_psd(raw, fmin=0, fmax=np.inf, tmin=None, tmax=None, proj=False, ---------- raw : instance of Raw The raw object. - fmin : float - Start frequency to consider. - fmax : float - End frequency to consider. - tmin : float | None - Start time to consider. - tmax : float | None - End time to consider. - proj : bool - Apply projection. + %(fmin_fmax_psd)s + %(tmin_tmax_psd)s + %(proj_psd)s n_fft : int | None - Number of points to use in Welch FFT calculations. - Default is None, which uses the minimum of 2048 and the - number of time points. + Number of points to use in Welch FFT calculations. Default is ``None``, + which uses the minimum of 2048 and the number of time points. n_overlap : int The number of points of overlap between blocks. The default value is 0 (no overlap). @@ -404,7 +393,7 @@ def plot_raw_psd(raw, fmin=0, fmax=np.inf, tmin=None, tmax=None, proj=False, %(n_jobs)s %(average_plot_psd)s %(line_alpha_plot_psd)s - %(spatial_colors_plot_psd)s + %(spatial_colors_psd)s %(sphere_topomap_auto)s %(window_psd)s @@ -421,57 +410,46 @@ def plot_raw_psd(raw, fmin=0, fmax=np.inf, tmin=None, tmax=None, proj=False, ------- fig : instance of Figure Figure with frequency spectra of the data channels. + + Notes + ----- + %(notes_plot_*_psd_func)s """ - from ._mpl_figure import _psd_figure - # handle FFT - if n_fft is None: - if tmax is None or not np.isfinite(tmax): - tmax = raw.times[-1] - tmin = 0. if tmin is None else tmin - n_fft = min(np.diff(raw.time_as_index([tmin, tmax]))[0] + 1, 2048) - # generate figure - fig = _psd_figure( - inst=raw, proj=proj, picks=picks, axes=ax, tmin=tmin, tmax=tmax, - fmin=fmin, fmax=fmax, sphere=sphere, xscale=xscale, dB=dB, - average=average, estimate=estimate, area_mode=area_mode, - line_alpha=line_alpha, area_alpha=area_alpha, color=color, - spatial_colors=spatial_colors, n_jobs=n_jobs, n_fft=n_fft, - n_overlap=n_overlap, reject_by_annotation=reject_by_annotation, - window=window, exclude=exclude) - plt_show(show) + fig = raw.plot_psd( + fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, picks=picks, + proj=proj, reject_by_annotation=reject_by_annotation, method='welch', + ax=ax, color=color, xscale=xscale, area_mode=area_mode, + area_alpha=area_alpha, dB=dB, estimate=estimate, show=show, + line_alpha=line_alpha, spatial_colors=spatial_colors, sphere=sphere, + exclude=exclude, n_jobs=n_jobs, average=average, verbose=verbose, + n_fft=n_fft, n_overlap=n_overlap, window=window) return fig @verbose def plot_raw_psd_topo(raw, tmin=0., tmax=None, fmin=0., fmax=100., proj=False, - n_fft=2048, n_overlap=0, layout=None, color='w', - fig_facecolor='k', axis_facecolor='k', dB=True, - show=True, block=False, n_jobs=None, axes=None, + *, n_fft=2048, n_overlap=0, dB=True, layout=None, + color='w', fig_facecolor='k', axis_facecolor='k', + axes=None, block=False, show=True, n_jobs=None, verbose=None): - """Plot channel-wise frequency spectra as topography. + """Plot power spectral density, separately for each channel. Parameters ---------- raw : instance of io.Raw The raw instance to use. - tmin : float - Start time for calculations. Defaults to zero. - tmax : float | None - End time for calculations. If None (default), the end of data is used. - fmin : float - Start frequency to consider. Defaults to zero. - fmax : float - End frequency to consider. Defaults to 100. - proj : bool - Apply projection. Defaults to False. + %(tmin_tmax_psd)s + %(fmin_fmax_psd_topo)s + %(proj_psd)s n_fft : int Number of points to use in Welch FFT calculations. Defaults to 2048. n_overlap : int The number of points of overlap between blocks. Defaults to 0 (no overlap). + %(dB_spectrum_plot_topo)s layout : instance of Layout | None Layout instance specifying sensor positions (does not need to be - specified for Neuromag data). If None (default), the correct layout is + specified for Neuromag data). If ``None`` (default), the layout is inferred from the data. color : str | tuple A matplotlib-compatible color to use for the curves. Defaults to white. @@ -481,16 +459,12 @@ def plot_raw_psd_topo(raw, tmin=0., tmax=None, fmin=0., fmax=100., proj=False, axis_facecolor : str | tuple A matplotlib-compatible color to use for the axis background. Defaults to black. - dB : bool - If True, transform data to decibels. Defaults to True. - show : bool - Show figure if True. Defaults to True. + %(axes_spectrum_plot_topo)s block : bool Whether to halt program execution until the figure is closed. May not work on all systems / platforms. Defaults to False. + %(show)s %(n_jobs)s - axes : instance of matplotlib Axes | None - Axes to plot into. If None, axes will be created. %(verbose)s Returns @@ -498,36 +472,11 @@ def plot_raw_psd_topo(raw, tmin=0., tmax=None, fmin=0., fmax=100., proj=False, fig : instance of matplotlib.figure.Figure Figure distributing one image per channel across sensor topography. """ - if layout is None: - from ..channels.layout import find_layout - layout = find_layout(raw.info) - - psds, freqs = psd_welch(raw, tmin=tmin, tmax=tmax, fmin=fmin, - fmax=fmax, proj=proj, n_fft=n_fft, - n_overlap=n_overlap, n_jobs=n_jobs) - if dB: - psds = 10 * np.log10(psds) - y_label = 'dB' - else: - y_label = 'Power' - show_func = partial(_plot_timeseries_unified, data=[psds], color=color, - times=[freqs]) - click_func = partial(_plot_timeseries, data=[psds], color=color, - times=[freqs]) - picks = _pick_data_channels(raw.info) - info = pick_info(raw.info, picks) - - fig = _plot_topo(info, times=freqs, show_func=show_func, - click_func=click_func, layout=layout, - axis_facecolor=axis_facecolor, - fig_facecolor=fig_facecolor, x_label='Frequency (Hz)', - unified=True, y_label=y_label, axes=axes) - - try: - plt_show(show, block=block) - except TypeError: # not all versions have this - plt_show(show) - return fig + return raw.plot_psd_topo( + tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, proj=proj, method='welch', + dB=dB, layout=layout, color=color, fig_facecolor=fig_facecolor, + axis_facecolor=axis_facecolor, axes=axes, block=block, show=show, + n_jobs=n_jobs, verbose=verbose, n_fft=n_fft, n_overlap=n_overlap) def _setup_channel_selections(raw, kind, order): diff --git a/mne/viz/tests/test_epochs.py b/mne/viz/tests/test_epochs.py index 716448df8b1..c9386087de2 100644 --- a/mne/viz/tests/test_epochs.py +++ b/mne/viz/tests/test_epochs.py @@ -386,16 +386,15 @@ def test_plot_psd_epochs_ctf(raw_ctf): """Test plotting CTF epochs psd (+topomap).""" evts = make_fixed_length_events(raw_ctf) epochs = Epochs(raw_ctf, evts, preload=True) - pytest.raises(RuntimeError, epochs.plot_psd_topomap, - bands=[(0, 0.01, 'foo')]) # no freqs in range - epochs.plot_psd_topomap() - # EEG060 is flat in this dataset for dB in [True, False]: with pytest.warns(UserWarning, match='for channel EEG060'): epochs.plot_psd(dB=dB) epochs.drop_channels(['EEG060']) epochs.plot_psd(spatial_colors=False, average=False) + with pytest.raises(RuntimeError, match='No frequencies in band'): + epochs.plot_psd_topomap(bands=[(0, 0.01, 'foo')]) + epochs.plot_psd_topomap() def test_plot_epochs_selection_butterfly(raw, browser_backend): diff --git a/mne/viz/tests/test_raw.py b/mne/viz/tests/test_raw.py index bf4fbcd0521..de3067904bd 100644 --- a/mne/viz/tests/test_raw.py +++ b/mne/viz/tests/test_raw.py @@ -816,11 +816,12 @@ def test_plot_raw_psd(raw, raw_orig): plt.close('all') # gh-7631 - data = 1e-3 * np.random.rand(2, 100) - info = create_info(['CH1', 'CH2'], 100) + n_times = sfreq = n_fft = 100 + data = 1e-3 * np.random.rand(2, n_times) + info = create_info(['CH1', 'CH2'], sfreq) # ch_types defaults to 'misc' raw = RawArray(data, info) picks = pick_types(raw.info, misc=True) - raw.plot_psd(picks=picks, spatial_colors=False) + raw.plot_psd(picks=picks, spatial_colors=False, n_fft=n_fft) plt.close('all') diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index a0e1dcb8a4f..c0fefbdcc1f 100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -34,7 +34,6 @@ plt_show, _process_times, DraggableColorbar, _get_cmap, _validate_if_list_of_axes, _setup_cmap, _check_time_unit, _set_3d_axes_equal, _check_type_projs) -from ..time_frequency import psd_multitaper from ..defaults import _handle_default from ..transforms import apply_trans, invert_transform from ..io.meas_info import Info, _simplify_info @@ -758,6 +757,8 @@ def plot_topomap(data, pos, vmin=None, vmax=None, cmap=None, sensors=True, %(sphere_topomap)s %(border_topomap)s %(ch_type_topomap)s + + .. versionadded:: 0.21 %(cnorm)s .. versionadded:: 0.24 @@ -1550,7 +1551,7 @@ def plot_evoked_topomap(evoked, times="auto", ch_type=None, automatically by checking for local maxima in global field power. If "interactive", the time can be set interactively at run-time by using a slider. - %(ch_type_evoked_topomap)s + %(ch_type_topomap)s %(vmin_vmax_topomap)s %(cmap_topomap)s %(sensors_topomap)s @@ -1940,14 +1941,19 @@ def _plot_topomap_multi_cbar(data, pos, ax, title=None, unit=None, vmin=None, @verbose -def plot_epochs_psd_topomap(epochs, bands=None, - tmin=None, tmax=None, proj=False, - bandwidth=None, adaptive=False, low_bias=True, - normalization='length', ch_type=None, - cmap=None, agg_fun=None, dB=False, n_jobs=None, - normalize=False, cbar_fmt='auto', - outlines='head', axes=None, show=True, - sphere=None, vlim=(None, None), verbose=None): +def plot_epochs_psd_topomap(epochs, bands=None, tmin=None, tmax=None, + proj=False, *, bandwidth=None, adaptive=False, + low_bias=True, normalization='length', + ch_type=None, normalize=False, agg_fun=None, + dB=False, # sensors=True, show_names=False, + # mask=None, mask_params=None, contours=6, + outlines='head', sphere=None, + # image_interp=_INTERPOLATION_DEFAULT, + # extrapolate=_EXTRAPOLATE_DEFAULT, + # border=_BORDER_DEFAULT, res=64, size=1, + cmap=None, vlim=(None, None), # colorbar=True, + cbar_fmt='auto', # units=None, + axes=None, show=True, n_jobs=None, verbose=None): """Plot the topomap of the power spectral density across epochs. Parameters @@ -1955,12 +1961,8 @@ def plot_epochs_psd_topomap(epochs, bands=None, epochs : instance of Epochs The epochs object. %(bands_psd_topo)s - tmin : float | None - Start time to consider. - tmax : float | None - End time to consider. - proj : bool - Apply projection. + %(tmin_tmax_psd)s + %(proj_psd)s bandwidth : float The bandwidth of the multi taper windowing function in Hz. The default value is a window half-bandwidth of 4 Hz. @@ -1971,60 +1973,59 @@ def plot_epochs_psd_topomap(epochs, bands=None, Only use tapers with more than 90%% spectral concentration within bandwidth. %(normalization)s - ch_type : 'mag' | 'grad' | 'planar1' | 'planar2' | 'eeg' | None - The channel type to plot. For 'grad', the gradiometers are collected in - pairs and the mean for each pair is plotted. If None, then first - available channel type from order given above is used. Defaults to - None. - %(cmap_psd_topo)s + %(ch_type_psd_topomap)s + %(normalize_psd_topo)s %(agg_fun_psd_topo)s %(dB_plot_topomap)s - %(n_jobs)s - %(normalize_psd_topo)s - %(cbar_fmt_psd_topo)s %(outlines_topomap)s - %(axes_plot_topomap)s - %(show)s %(sphere_topomap_auto)s + %(cmap_psd_topo)s %(vlim_psd_topo_joint)s + %(cbar_fmt_psd_topo)s + %(axes_plot_topomap)s + %(show)s + %(n_jobs)s %(verbose)s Returns ------- fig : instance of Figure - Figure distributing one image per channel across sensor topography. + Figure showing one scalp topography per frequency band. """ - ch_type = _get_ch_type(epochs, ch_type) - units = _handle_default('units', None) - scalings = _handle_default('scalings', None) - unit = units[ch_type] - scaling = scalings[ch_type] - - picks, pos, merge_channels, names, ch_type, sphere, clip_origin = \ - _prepare_topomap_plot(epochs, ch_type, sphere=sphere) - outlines = _make_head_outlines(sphere, pos, outlines, clip_origin) - - psds, freqs = psd_multitaper(epochs, tmin=tmin, tmax=tmax, - bandwidth=bandwidth, adaptive=adaptive, - low_bias=low_bias, - normalization=normalization, picks=picks, - proj=proj, n_jobs=n_jobs) - psds = np.mean(psds, axis=0) - psds *= scaling**2 - - if merge_channels: - psds, names = _merge_ch_data(psds, ch_type, names, method='mean') - - return plot_psds_topomap( - psds=psds, freqs=freqs, pos=pos, agg_fun=agg_fun, - bands=bands, cmap=cmap, dB=dB, normalize=normalize, - cbar_fmt=cbar_fmt, outlines=outlines, axes=axes, show=show, - sphere=sphere, vlim=vlim, unit=unit, ch_type=ch_type) + # add after dB + # %(sensors_topomap)s + # %(show_names_topomap)s + # %(mask_evoked_topomap)s + # %(mask_params_topomap)s + # %(contours_topomap)s + # add after sphere + # %(image_interp_topomap)s + # %(extrapolate_topomap)s + # %(border_topomap)s + # %(res_topomap)s + # %(size_topomap)s + # add after vlim + # %(colorbar_topomap)s + # add after cbar_fmt + # %(units_topomap)s + return epochs.plot_psd_topomap( + bands=bands, tmin=tmin, tmax=tmax, proj=proj, method='multitaper', + ch_type=ch_type, normalize=normalize, agg_fun=agg_fun, dB=dB, + # sensors=sensors, show_names=show_names, mask=mask, + # mask_params=mask_params, contours=contours, + outlines=outlines, + sphere=sphere, # image_interp=image_interp, extrapolate=extrapolate, + # border=border, res=res, size=size, + cmap=cmap, vlim=vlim, # colorbar=colorbar, + cbar_fmt=cbar_fmt, # units=units, + axes=None, + show=True, n_jobs=None, verbose=None, bandwidth=bandwidth, + low_bias=low_bias, adaptive=adaptive, normalization=normalization) @fill_doc def plot_psds_topomap( - psds, freqs, pos, agg_fun=None, bands=None, + psds, freqs, pos, *, agg_fun=None, bands=None, cmap=None, dB=True, normalize=False, cbar_fmt='%0.3f', outlines='head', axes=None, show=True, sphere=None, vlim=(None, None), unit=None, ch_type='eeg'): @@ -2046,8 +2047,7 @@ def plot_psds_topomap( %(cbar_fmt_psd_topo)s %(outlines_topomap)s %(axes_plot_topomap)s - show : bool - Show figure if True. + %(show)s %(sphere_topomap)s %(vlim_psd_topo_joint)s unit : str | None diff --git a/mne/viz/utils.py b/mne/viz/utils.py index fae50b0f4d3..bc18af396eb 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -31,7 +31,6 @@ from ..fixes import _get_args from ..io import show_fiff, Info from ..io.constants import FIFF -from ..io.meas_info import create_info from ..io.pick import (channel_type, channel_indices_by_type, pick_channels, _pick_data_channels, _DATA_CH_TYPES_SPLIT, _DATA_CH_TYPES_ORDER_DEFAULT, _VALID_CHANNEL_TYPES, @@ -2216,6 +2215,7 @@ def _plot_psd(inst, fig, freqs, psd_list, picks_list, titles_list, # helper function for plot_raw_psd and plot_epochs_psd from matplotlib.ticker import ScalarFormatter from .evoked import _plot_lines + from ..stats import _ci for key, ls in zip(['lowpass', 'highpass', 'line_freq'], ['--', '--', '-.']): @@ -2238,15 +2238,17 @@ def _plot_psd(inst, fig, freqs, psd_list, picks_list, titles_list, if average: # mean across channels psd_mean = np.mean(psd, axis=0) - if area_mode == 'std': + if area_mode in ('sd', 'std'): # std across channels psd_std = np.std(psd, axis=0) hyp_limits = (psd_mean - psd_std, psd_mean + psd_std) elif area_mode == 'range': hyp_limits = (np.min(psd, axis=0), np.max(psd, axis=0)) - else: # area_mode is None + elif area_mode is None: hyp_limits = None + else: # area_mode is float + hyp_limits = _ci(psd, ci=area_mode) ax.plot(freqs, psd_mean, color=color, alpha=line_alpha, linewidth=0.5) @@ -2256,14 +2258,8 @@ def _plot_psd(inst, fig, freqs, psd_list, picks_list, titles_list, if not average: picks = np.concatenate(picks_list) - psd_list = np.concatenate(psd_list) - types = np.array(inst.get_channel_types(picks=picks)) - # Needed because the data do not match the info anymore. - info = create_info([inst.ch_names[p] for p in picks], - inst.info['sfreq'], types) - with info._unlock(): - info['chs'] = [inst.info['chs'][p] for p in picks] - info['dev_head_t'] = inst.info['dev_head_t'] + info = pick_info(inst.info, sel=picks, copy=True) + types = np.array(info.get_channel_types()) ch_types_used = list() for this_type in _VALID_CHANNEL_TYPES: if this_type in types: @@ -2272,10 +2268,13 @@ def _plot_psd(inst, fig, freqs, psd_list, picks_list, titles_list, unit = '' units = {t: yl for t, yl in zip(ch_types_used, ylabels)} titles = {c: t for c, t in zip(ch_types_used, titles_list)} - picks = np.arange(len(psd_list)) + # here we overwrite `picks` because of how _plot_lines works; + # we already have the data, ch_types, etc in sync. + psd_array = np.concatenate(psd_list) + picks = np.arange(len(psd_array)) if not spatial_colors: spatial_colors = color - _plot_lines(psd_list, info, picks, fig, ax_list, spatial_colors, + _plot_lines(psd_array, info, picks, fig, ax_list, spatial_colors, unit, units=units, scalings=None, hline=None, gfp=False, types=types, zorder='std', xlim=(freqs[0], freqs[-1]), ylim=None, times=freqs, bad_ch_idx=[], titles=titles, diff --git a/setup.cfg b/setup.cfg index 8a8b8675451..384cb50f282 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,10 +15,13 @@ release = egg_info -RDb '' doc_files = doc [flake8] -exclude = __init__.py,constants.py,fixes.py,resources.py,*doc/auto_examples*,*doc/_build* +exclude = __init__.py,constants.py,fixes.py,resources.py,*doc/auto_*,*doc/_build* ignore = W503,W504,I100,I101,I201,N806,E201,E202,E221,E222,E241 # We add A for the array-spacing plugin, and ignore the E ones it covers above select = A,E,F,W,C +# 10_spectrum_class.py has a wide rST table +per-file-ignores = + tutorials/time-freq/10_spectrum_class.py:E501 [tool:pytest] addopts = diff --git a/tutorials/clinical/60_sleep.py b/tutorials/clinical/60_sleep.py index d99833300e4..dee1e409e07 100644 --- a/tutorials/clinical/60_sleep.py +++ b/tutorials/clinical/60_sleep.py @@ -39,7 +39,6 @@ import mne from mne.datasets.sleep_physionet.age import fetch_data -from mne.time_frequency import psd_welch from sklearn.ensemble import RandomForestClassifier from sklearn.metrics import accuracy_score @@ -194,13 +193,12 @@ [epochs_train, epochs_test]): for stage, color in zip(stages, stage_colors): - epochs[stage].plot_psd(area_mode=None, color=color, ax=ax, - fmin=0.1, fmax=20., show=False, - average=True, spatial_colors=False) + spectrum = epochs[stage].compute_psd(fmin=0.1, fmax=20.) + spectrum.plot(ci=None, color=color, axes=ax, + show=False, average=True, spatial_colors=False) ax.set(title=title, xlabel='Frequency (Hz)') -ax2.set(ylabel='µV^2/Hz (dB)') +ax1.set(ylabel='µV²/Hz (dB)') ax2.legend(ax2.lines[2::3], stages) -plt.show() ############################################################################## # Design a scikit-learn transformer from a Python function @@ -235,7 +233,8 @@ def eeg_power_band(epochs): "sigma": [11.5, 15.5], "beta": [15.5, 30]} - psds, freqs = psd_welch(epochs, picks='eeg', fmin=0.5, fmax=30.) + spectrum = epochs.compute_psd(picks='eeg', fmin=0.5, fmax=30.) + psds, freqs = spectrum.get_data(return_freqs=True) # Normalize the PSDs psds /= np.sum(psds, axis=-1, keepdims=True) diff --git a/tutorials/epochs/20_visualize_epochs.py b/tutorials/epochs/20_visualize_epochs.py index 017df66409f..32046990f7c 100644 --- a/tutorials/epochs/20_visualize_epochs.py +++ b/tutorials/epochs/20_visualize_epochs.py @@ -130,39 +130,56 @@ # Plotting the power spectrum of ``Epochs`` # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # -# Again, just like `~mne.io.Raw` objects, `~mne.Epochs` objects -# have a `~mne.Epochs.plot_psd` method for plotting the `spectral -# density`_ of the data. +# Again, just like `~mne.io.Raw` objects, :class:`~mne.Epochs` objects +# can be converted to `spectral density`_ via +# :meth:`~mne.Epochs.compute_psd`, which can then be plotted using the +# :class:`~mne.time_frequency.EpochsSpectrum`'s +# :meth:`~mne.time_frequency.EpochsSpectrum.plot` method. -epochs['auditory'].plot_psd(picks='eeg') +epochs['auditory'].compute_psd().plot(picks='eeg') # %% -# It is also possible to plot spectral estimates across sensors as a scalp -# topography, using `~mne.Epochs.plot_psd_topomap`. The default parameters will -# plot five frequency bands (δ, θ, α, β, γ), will compute power based on -# magnetometer channels, and will plot the power estimates in decibels: +# It is also possible to plot spectral power estimates across sensors as a +# scalp topography, using the :class:`~mne.time_frequency.EpochsSpectrum`'s +# :meth:`~mne.time_frequency.EpochsSpectrum.plot_topomap` method. The default +# parameters will plot five frequency bands (δ, θ, α, β, γ), will compute power +# based on magnetometer channels (if present), and will plot the power +# estimates on a dB-like log-scale: -epochs['visual/right'].plot_psd_topomap() +spectrum = epochs['visual/right'].compute_psd() +spectrum.plot_topomap() # %% +# .. note:: +# Prior to the addition of the :class:`~mne.time_frequency.EpochsSpectrum` +# class, the above plots were possible via:: +# +# epochs['auditory'].plot_psd(picks='eeg') +# epochs['visual/right'].plot_psd_topomap() +# +# The :meth:`~mne.Epochs.plot_psd` and `~mne.Epochs.plot_psd_topomap` +# methods of :class:`~mne.Epochs` objects are still provided to support +# legacy analysis scripts, but new code should instead use the +# :class:`~mne.time_frequency.EpochsSpectrum` object API. +# # Just like `~mne.Epochs.plot_projs_topomap`, -# `~mne.Epochs.plot_psd_topomap` has a ``vlim='joint'`` option for fixing -# the colorbar limits jointly across all subplots, to give a better sense of -# the relative magnitude in each frequency band. You can change which channel -# type is used via the ``ch_type`` parameter, and if you want to view -# different frequency bands than the defaults, the ``bands`` parameter takes a -# :class:`dict`, with keys providing a subplot title and values providing -# either single frequency bins to plot, or lower/upper frequency band edges: +# `EpochsSpectrum.plot_topomap()` +# has a ``vlim='joint'`` option for fixing the colorbar limits jointly across +# all subplots, to give a better sense of the relative magnitude in each +# frequency band. You can change which channel type is used via the +# ``ch_type`` parameter, and if you want to view different frequency bands than +# the defaults, the ``bands`` parameter takes a :class:`dict`, with keys +# providing a subplot title and values providing either single frequency bins +# to plot, or lower/upper frequency band edges: bands = {'10 Hz': 10, '15 Hz': 15, '20 Hz': 20, '10-20 Hz': (10, 20)} -epochs['visual/right'].plot_psd_topomap(bands=bands, vlim='joint', - ch_type='grad') +spectrum.plot_topomap(bands=bands, vlim='joint', ch_type='grad') # %% # If you prefer untransformed power estimates, you can pass ``dB=False``. It is # also possible to normalize the power estimates by dividing by the total power # across all frequencies, by passing ``normalize=True``. See the docstring of -# `~mne.Epochs.plot_psd_topomap` for details. +# `~mne.time_frequency.EpochsSpectrum.plot_topomap` for details. # # # Plotting ``Epochs`` as an image map diff --git a/tutorials/raw/40_visualize_raw.py b/tutorials/raw/40_visualize_raw.py index 51b241ff25d..6fb07017fae 100644 --- a/tutorials/raw/40_visualize_raw.py +++ b/tutorials/raw/40_visualize_raw.py @@ -18,6 +18,7 @@ # %% import os + import mne sample_data_folder = mne.datasets.sample.data_path() @@ -32,13 +33,12 @@ # but `~mne.io.Raw` objects also have several built-in plotting methods: # # - `~mne.io.Raw.plot` -# - `~mne.io.Raw.plot_psd` -# - `~mne.io.Raw.plot_psd_topo` # - `~mne.io.Raw.plot_sensors` # - `~mne.io.Raw.plot_projs_topomap` # -# The first three are discussed here in detail; the last two are shown briefly -# and covered in-depth in other tutorials. +# The first one is discussed here in detail; the last two are shown briefly +# and covered in-depth in other tutorials. This tutorial also covers a few +# ways of plotting the spectral content of :class:`~mne.io.Raw` data. # # # Interactive data browsing with ``Raw.plot()`` @@ -114,10 +114,12 @@ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # # To visualize the frequency content of continuous data, the `~mne.io.Raw` -# object provides a `~mne.io.Raw.plot_psd` to plot the `spectral density`_ of -# the data. +# object provides a :meth:`~mne.io.Raw.compute_psd` method to compute +# `spectral density`_ and the resulting :class:`~mne.time_frequency.Spectrum` +# object has a :meth:`~mne.time_frequency.Spectrum.plot` method: -raw.plot_psd(average=True) +spectrum = raw.compute_psd() +spectrum.plot(average=True) # %% # If the data have been filtered, vertical dashed lines will automatically @@ -128,32 +130,58 @@ # color-coding the channels by location, and more. For example, here is a plot # of just a few sensors (specified with the ``picks`` parameter), color-coded # by spatial location (via the ``spatial_colors`` parameter, see the -# documentation of `~mne.io.Raw.plot_psd` for full details): +# documentation of `~mne.time_frequency.Spectrum.plot` for full details): midline = ['EEG 002', 'EEG 012', 'EEG 030', 'EEG 048', 'EEG 058', 'EEG 060'] -raw.plot_psd(picks=midline) +spectrum.plot(picks=midline) + +# %% +# It is also possible to plot spectral power estimates across sensors as a +# scalp topography, using the :class:`~mne.time_frequency.Spectrum`'s +# :meth:`~mne.time_frequency.Spectrum.plot_topomap` method. The default +# parameters will plot five frequency bands (δ, θ, α, β, γ), will compute power +# based on magnetometer channels (if present), and will plot the power +# estimates on a dB-like log-scale: + +spectrum.plot_topomap() # %% # Alternatively, you can plot the PSD for every sensor on its own axes, with # the axes arranged spatially to correspond to sensor locations in space, using -# `~mne.io.Raw.plot_psd_topo`: +# `~mne.time_frequency.Spectrum.plot_topo`: -raw.plot_psd_topo() +spectrum.plot_topo() # %% # This plot is also interactive; hovering over each "thumbnail" plot will # display the channel name in the bottom left of the plot window, and clicking # on a thumbnail plot will create a second figure showing a larger version of # the selected channel's spectral density (as if you had called -# `~mne.io.Raw.plot_psd` on that channel). +# `~mne.time_frequency.Spectrum.plot` with that channel passed as ``picks``). # -# By default, `~mne.io.Raw.plot_psd_topo` will show only the MEG +# By default, `~mne.time_frequency.Spectrum.plot_topo` will show only the MEG # channels if MEG channels are present; if only EEG channels are found, they # will be plotted instead: -raw.copy().pick_types(meg=False, eeg=True).plot_psd_topo() +spectrum.pick('eeg').plot_topo() # %% +# .. note:: +# +# Prior to the addition of the :class:`~mne.time_frequency.Spectrum` class, +# the above plots were possible via:: +# +# raw.plot_psd(average=True) +# raw.plot_psd_topo() +# raw.pick('eeg').plot_psd_topo() +# +# (there was no ``plot_topomap`` method for :class:`~mne.io.Raw`). The +# :meth:`~mne.io.Raw.plot_psd` and :meth:`~mne.io.Raw.plot_psd_topo` methods +# of :class:`~mne.io.Raw` objects are still provided to support legacy +# analysis scripts, but new code should instead use the +# :class:`~mne.time_frequency.Spectrum` object API. +# +# # Plotting sensor locations from ``Raw`` objects # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # diff --git a/tutorials/time-freq/10_spectrum_class.py b/tutorials/time-freq/10_spectrum_class.py new file mode 100644 index 00000000000..db222c0f7d0 --- /dev/null +++ b/tutorials/time-freq/10_spectrum_class.py @@ -0,0 +1,172 @@ +# -*- coding: utf-8 -*- +# noqa: E501 +""" +.. _tut-spectrum-class: + +============================================================== +The Spectrum and EpochsSpectrum classes: frequency-domain data +============================================================== + +This tutorial shows how to create and visualize frequency-domain +representations of your data, starting from continuous :class:`~mne.io.Raw`, +discontinuous :class:`~mne.Epochs`, or averaged :class:`~mne.Evoked` data. + +As usual we'll start by importing the modules we need, and loading our +:ref:`sample dataset `: +""" + +# %% +import numpy as np + +import mne + +sample_data_folder = mne.datasets.sample.data_path() +sample_data_raw_file = (sample_data_folder / 'MEG' / 'sample' / + 'sample_audvis_raw.fif') +raw = mne.io.read_raw_fif(sample_data_raw_file, verbose=False).crop(tmax=60) + +# %% +# All three sensor-space containers (:class:`~mne.io.Raw`, +# :class:`~mne.Epochs`, and :class:`~mne.Evoked`) have a +# :meth:`~mne.io.Raw.compute_psd` method with the same options. + +raw.compute_psd() + +# %% +# By default, the spectral estimation method will be the +# :footcite:t:`Welch1967` method for continuous data, and the multitaper +# method :footcite:`Slepian1978` for epoched or averaged data. This default can +# be overridden by passing ``method='welch'`` or ``method='multitaper'`` to the +# :meth:`~mne.io.Raw.compute_psd` method. +# +# There are many other options available as well; for example we can compute a +# spectrum from a given span of times, for a chosen frequency range, and for a +# subset of the available channels: + +raw.compute_psd(method='multitaper', tmin=10, tmax=20, fmin=5, fmax=30, + picks='eeg') + +# %% +# You can also pass some parameters to the underlying spectral estimation +# function, such as the FFT window length and overlap for the Welch method; see +# the docstrings of :class:`mne.time_frequency.Spectrum` (esp. its +# ``method_kw`` parameter) and the spectral estimation functions +# :func:`~mne.time_frequency.psd_array_welch` and +# :func:`~mne.time_frequency.psd_array_multitaper` for details. +# +# For epoched data, the class of the spectral estimate will be +# :class:`mne.time_frequency.EpochsSpectrum` instead of +# :class:`mne.time_frequency.Spectrum`, but most of the API is the same for the +# two classes. For example, both have a +# :meth:`~mne.time_frequency.EpochsSpectrum.get_data` method with an option to +# return the bin frequencies: + +with mne.use_log_level('WARNING'): # hide some irrelevant info messages + events = mne.find_events(raw, stim_channel='STI 014') + event_dict = {'auditory/left': 1, 'auditory/right': 2, 'visual/left': 3, + 'visual/right': 4} + epochs = mne.Epochs(raw, events, tmin=-0.3, tmax=0.7, event_id=event_dict, + preload=True) +epo_spectrum = epochs.compute_psd() +psds, freqs = epo_spectrum.get_data(return_freqs=True) +print(f'\nPSDs shape: {psds.shape}, freqs shape: {freqs.shape}') +epo_spectrum + +# %% +# Additionally, both :class:`~mne.time_frequency.Spectrum` and +# :class:`~mne.time_frequency.EpochsSpectrum` have ``__getitem__`` methods, +# meaning their data can be accessed by square-bracket indexing. For +# :class:`~mne.time_frequency.Spectrum` objects (computed from +# :class:`~mne.io.Raw` or :class:`~mne.Evoked` data), the indexing works +# similar to a :class:`~mne.io.Raw` object or a +# :class:`NumPy array`: + +evoked = epochs['auditory'].average() +evk_spectrum = evoked.compute_psd() +# the first 3 frequency bins for the first 4 channels: +print(evk_spectrum[:4, :3]) + +# %% +# .. hint:: +# :class: sidebar +# +# If the original :class:`~mne.Epochs` object had a metadata dataframe +# attached, the derived :class:`~mne.time_frequency.EpochsSpectrum` will +# inherit that metadata and will hence also support subselecting epochs via +# :ref:`Pandas query strings `. +# +# In contrast, the :class:`~mne.time_frequency.EpochsSpectrum` has indexing +# similar to :class:`~mne.Epochs` objects: you can use string values to select +# spectral estimates for specific epochs based on their condition names, and +# what you get back is a new instance of +# :class:`~mne.time_frequency.EpochsSpectrum` rather than a +# :class:`NumPy array` of the data values. Selection via +# :term:`hierarchical event descriptors` (HEDs) is also possible: + +# get both "visual/left" and "visual/right" epochs: +epo_spectrum['visual'] + +# %% +# Visualizing Spectrum objects +# ---------------------------- +# +# Both :class:`~mne.time_frequency.Spectrum` and +# :class:`~mne.time_frequency.EpochsSpectrum` objects have plotting methods +# :meth:`~mne.time_frequency.Spectrum.plot` (frequency × power), +# :meth:`~mne.time_frequency.Spectrum.plot_topo` (frequency × power separately +# for each sensor), and :meth:`~mne.time_frequency.Spectrum.plot_topomap` +# (interpolated scalp topography of power, in specific frequency bands). A few +# plot options are demonstrated below; see the docstrings for full details. + +evk_spectrum.plot() +evk_spectrum.plot_topo(color='k', fig_facecolor='w', axis_facecolor='w') + +# %% +evk_spectrum.plot_topomap(ch_type='eeg', agg_fun=np.median) + +# %% +# Migrating legacy code +# --------------------- +# +# Below is a quick-reference table of equivalent code from before and after the +# introduction of the :class:`~mne.time_frequency.Spectrum` and +# :class:`~mne.time_frequency.EpochsSpectrum` classes. +# +# .. table:: Quick reference for common Spectral class actions +# :widths: auto +# +# +---------------------------------------------------+----------------------------------------------------------------------+ +# | Old | New | +# +===================================================+======================================================================+ +# | ``mne.time_frequency.psd_welch(raw)`` | ``raw.compute_psd().get_data(return_freqs=True)`` | +# +---------------------------------------------------+----------------------------------------------------------------------+ +# | ``mne.time_frequency.psd_multitaper(raw)`` | ``raw.compute_psd(method='multitaper').get_data(return_freqs=True)`` | +# +---------------------------------------------------+----------------------------------------------------------------------+ +# | ``raw.plot_psd(fmin, fmax, dB, area_mode='std')`` | ``raw.compute_psd(fmin, fmax).plot(dB, ci='std')`` | +# +---------------------------------------------------+----------------------------------------------------------------------+ +# | ``raw.plot_psd_topo(n_fft, overlap, axes)`` | ``raw.compute_psd(n_fft, overlap).plot_topo(axes)`` | +# +---------------------------------------------------+----------------------------------------------------------------------+ +# | ``epochs.plot_psd_topomap(tmax, bands)`` | ``epochs.compute_psd(tmax).plot_topomap(bands)`` | +# +---------------------------------------------------+----------------------------------------------------------------------+ +# +# +# .. warning:: +# +# The functions :func:`mne.time_frequency.psd_welch` and +# :func:`mne.time_frequency.psd_multitaper` have been deprecated; new code +# should use the :meth:`Raw.compute_psd()`, +# :meth:`Epochs.compute_psd()`, and +# :meth:`Evoked.compute_psd()` methods, and pass +# ``method='welch'`` or ``method='multitaper'`` as a parameter. +# +# The class methods :meth:`Raw.plot_psd()`, +# :meth:`Epochs.plot_psd()`, +# :meth:`Raw.plot_psd_topo()`, and +# :meth:`Epochs.plot_psd_topomap()` have been +# kept in the API to support legacy code, but should be avoided when writing +# new code. +# +# +# References +# ---------- +# .. footbibliography:: From a02fa0a386a1feb770caaeaa91678b753ab8e48e Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Sat, 27 Aug 2022 17:28:23 +0100 Subject: [PATCH 4/7] BUG: Improve logic for bti (#11102) * BUG: Fix BTi channel logic * FIX: Old Python * FIX: Fine * FIX: Green --- doc/changes/latest.inc | 2 +- mne/io/bti/bti.py | 37 ++++++++++++++----- mne/io/bti/tests/test_bti.py | 69 ++++++++++++++++++++++++++++++++---- 3 files changed, 93 insertions(+), 15 deletions(-) diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index b969ac037ec..14fcd700df6 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -46,8 +46,8 @@ Bugs - Fix bug in :class:`mne.viz.Brain` constructor where the first argument was named ``subject_id`` instead of ``subject`` (:gh:`11049` by `Eric Larson`_) - Fix bug in :ref:`mne coreg` where the MEG helmet position was not updated during ICP fitting (:gh:`11084` by `Eric Larson`_) - Document ``height`` and ``weight`` keys of ``subject_info`` entry in :class:`mne.Info` (:gh:`11019` by :newcontrib:`Sena Er`) -- Fixed bug in :func:`mne.viz.plot_filter` when plotting filters created using ``output='ba'`` mode with ``compensation`` turned on. (by `Marian Dovgialo`_) - Fix bug in :func:`mne.viz.plot_filter` when plotting filters created using ``output='ba'`` mode with ``compensation`` turned on. (:gh:`11040` by `Marian Dovgialo`_) +- Fix bug in :func:`mne.io.read_raw_bti` where EEG, EMG, and H/VEOG channels were not detected properly, and many non-ECG channels were called ECG. The logic has been improved, and any channels of unknown type are now labeled as ``misc`` (:gh:`11102` by `Eric Larson`_) - Fix bug in :func:`mne.viz.plot_topomap` when providing ``sphere="eeglab"`` (:gh:`11081` by `Mathieu Scheltienne`_) - Applying a montage where EEG locations are not in head space (or unknown space) without fiducials will now raise an error message. (:gh:`11080` by `Marijn van Vliet`_) diff --git a/mne/io/bti/bti.py b/mne/io/bti/bti.py index 17a70ac5d6b..2204d1b59dd 100644 --- a/mne/io/bti/bti.py +++ b/mne/io/bti/bti.py @@ -8,6 +8,7 @@ # # simplified BSD-3 license +import functools import os.path as op from io import BytesIO from itertools import count @@ -48,7 +49,7 @@ def _instantiate_default_info_chs(): unit=FIFF.FIFF_UNIT_V, cal=1.0, scanno=None, - kind=FIFF.FIFFV_ECG_CH, + kind=FIFF.FIFFV_MISC_CH, logno=None) @@ -997,6 +998,23 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): _mult_cal_one(data_view, one, idx, cals, mult) +@functools.lru_cache(1) +def _1020_names(): + from mne.channels import make_standard_montage + return set(ch_name.lower() + for ch_name in make_standard_montage('standard_1005').ch_names) + + +def _eeg_like(ch_name): + # Some bti recordigs look like "F4-POz", so let's at least mark them + # as EEG + if ch_name.count('-') != 1: + return + ch, ref = ch_name.split('-') + eeg_names = _1020_names() + return ch.lower() in eeg_names and ref.lower() in eeg_names + + def _make_bti_digitization( info, head_shape_fname, convert, use_hpi, bti_dev_t, dev_ctf_t): with info._unlock(): @@ -1184,7 +1202,7 @@ def _get_bti_info(pdf_fname, config_fname, head_shape_fname, rotation_x, chan_info['coil_type'] = \ FIFF.FIFFV_COIL_MAGNES_OFFDIAG_REF_GRAD - elif chan_4d.startswith('EEG'): + elif chan_4d.startswith('EEG') or _eeg_like(chan_4d): chan_info['kind'] = FIFF.FIFFV_EEG_CH chan_info['coil_type'] = FIFF.FIFFV_COIL_EEG chan_info['coord_frame'] = eeg_frame @@ -1194,14 +1212,17 @@ def _get_bti_info(pdf_fname, config_fname, head_shape_fname, rotation_x, chan_info['kind'] = FIFF.FIFFV_STIM_CH elif chan_4d == 'TRIGGER': chan_info['kind'] = FIFF.FIFFV_STIM_CH - elif chan_4d.startswith('EOG') or chan_4d in eog_ch: + elif chan_4d.startswith('EOG') or \ + chan_4d[:4] in ('HEOG', 'VEOG') or chan_4d in eog_ch: chan_info['kind'] = FIFF.FIFFV_EOG_CH - elif chan_4d == ecg_ch: + elif chan_4d.startswith('EMG'): + chan_info['kind'] = FIFF.FIFFV_EMG_CH + elif chan_4d == ecg_ch or chan_4d.startswith('ECG'): chan_info['kind'] = FIFF.FIFFV_ECG_CH - elif chan_4d.startswith('X'): - chan_info['kind'] = FIFF.FIFFV_MISC_CH - elif chan_4d == 'UACurrent': - chan_info['kind'] = FIFF.FIFFV_MISC_CH + # Our default is now misc, but if we ever change that, + # we'll need this: + # elif chan_4d.startswith('X') or chan_4d == 'UACurrent': + # chan_info['kind'] = FIFF.FIFFV_MISC_CH chs.append(chan_info) diff --git a/mne/io/bti/tests/test_bti.py b/mne/io/bti/tests/test_bti.py index 56d25ba0542..324936e400d 100644 --- a/mne/io/bti/tests/test_bti.py +++ b/mne/io/bti/tests/test_bti.py @@ -2,6 +2,7 @@ # # License: BSD-3-Clause +from collections import Counter from io import BytesIO import os import os.path as op @@ -38,12 +39,10 @@ for a in archs] tmp_raw_fname = op.join(base_dir, 'tmp_raw.fif') -fname_2500 = op.join(testing.data_path(download=False), 'BTi', 'erm_HFH', - 'c,rfDC') -fname_sim = op.join(testing.data_path(download=False), 'BTi', '4Dsim', - 'c,rfDC') -fname_sim_filt = op.join(testing.data_path(download=False), 'BTi', '4Dsim', - 'c,rfDC,fn50,o') +testing_path_bti = testing.data_path(download=False) / 'BTi' +fname_2500 = testing_path_bti / 'erm_HFH' / 'c,rfDC' +fname_sim = testing_path_bti / '4Dsim' / 'c,rfDC' +fname_sim_filt = testing_path_bti / '4Dsim' / 'c,rfDC,fn50,o' # the 4D exporter doesn't export all channels, so we confine our comparison NCH = 248 @@ -383,3 +382,61 @@ def test_bti_set_eog(): preload=False, eog_ch=('X65', 'X67', 'X69', 'X66', 'X68')) assert_equal(len(pick_types(raw.info, eog=True)), 5) + + +@testing.requires_testing_data +def test_bti_ecg_eog_emg(monkeypatch): + """Test that EOG/ECG/EMG are set properly in BTi.""" + kwargs = dict(rename_channels=False, head_shape_fname=None) + raw = read_raw_bti(fname_2500, **kwargs) + ch_types = raw.get_channel_types() + got = Counter(ch_types) + # Before improving the triaging in gh-, these values were: + # want = dict(mag=148, ref_meg=11, ecg=32, stim=2, misc=1) + want = dict(mag=148, ref_meg=11, ecg=1, stim=2, misc=1, eeg=31) + assert set(want) == set(got) + for key in want: + assert want[key] == got[key], key + + # replace channel names with some from HCP (starting from the end) + # not including UACurrent (misc) or TRIGGER/RESPONSE (stim) b/c they + # already exist + got_map = dict(zip(raw.ch_names, ch_types)) + kind_map = dict( + stim=['TRIGGER', 'RESPONSE'], + misc=['UACurrent'], + ) + for kind, ch_names in kind_map.items(): + for ch_name in ch_names: + assert got_map[ch_name] == kind + kind_map = dict( + misc=['SA1', 'SA2', 'SA3'], + ecg=['ECG+', 'ECG-'], + eog=['VEOG+', 'HEOG+', 'VEOG-', 'HEOG-'], + emg=['EMG_LF', 'EMG_LH', 'EMG_RF', 'EMG_RH'], + ) + new_names = sum(kind_map.values(), list()) + assert len(new_names) == 13 + assert set(new_names).intersection(set(raw.ch_names)) == set() + + def _read_bti_header_2(*args, **kwargs): + bti_info = _read_bti_header(*args, **kwargs) + for ch_name, ch in zip(new_names, bti_info['chs'][::-1]): + ch['chan_label'] = ch_name + return bti_info + + monkeypatch.setattr(mne.io.bti.bti, '_read_bti_header', _read_bti_header_2) + raw = read_raw_bti(fname_2500, **kwargs) + got_map = dict(zip(raw.ch_names, raw.get_channel_types())) + got = Counter(got_map.values()) + want = dict(mag=148, ref_meg=11, misc=1, stim=2, eeg=19) + for kind, ch_names in kind_map.items(): + want[kind] = want.get(kind, 0) + len(ch_names) + assert set(want) == set(got) + for key in want: + assert want[key] == got[key], key + for kind, ch_names in kind_map.items(): + for ch_name in ch_names: + assert ch_name in raw.ch_names + err_msg = f'{ch_name} type {got_map[ch_name]} !+ {kind}' + assert got_map[ch_name] == kind, err_msg From 532cfc5cee360dff6bad9cac7710c805aee9a347 Mon Sep 17 00:00:00 2001 From: Alexandre Gramfort Date: Sun, 28 Aug 2022 15:37:13 +0100 Subject: [PATCH 5/7] add test for edf units param (#11105) * add test for edf units param * flake8 --- mne/io/edf/edf.py | 2 +- mne/io/edf/tests/test_edf.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/mne/io/edf/edf.py b/mne/io/edf/edf.py index f71b782cb8c..0b2c5f79589 100644 --- a/mne/io/edf/edf.py +++ b/mne/io/edf/edf.py @@ -147,7 +147,7 @@ def __init__(self, input_fname, eog=None, misc=None, stim_channel='auto', units = dict() for k, (this_ch, this_unit) in enumerate(orig_units.items()): - if this_unit != "" and this_unit in units: + if this_unit != "" and this_ch in units: raise ValueError(f'Unit for channel {this_ch} is present in ' 'the file. Cannot overwrite it with the ' 'units argument.') diff --git a/mne/io/edf/tests/test_edf.py b/mne/io/edf/tests/test_edf.py index 7426e76b66f..975059c328f 100644 --- a/mne/io/edf/tests/test_edf.py +++ b/mne/io/edf/tests/test_edf.py @@ -77,6 +77,14 @@ def test_orig_units(): assert orig_units['A1'] == 'µV' # formerly 'uV' edit by _check_orig_units +def test_units_params(): + """Test enforcing original channel units.""" + with pytest.raises(ValueError, + match=r"Unit for channel .* is present .* Cannot " + "overwrite it"): + _ = read_raw_edf(edf_path, units='V', preload=True) + + def test_subject_info(tmp_path): """Test exposure of original channel units.""" raw = read_raw_edf(edf_path) From 9385b96ad12c26e8f7d3b913b330eade4ee1365a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20H=C3=B6chenberger?= Date: Sun, 28 Aug 2022 23:59:05 +0200 Subject: [PATCH 6/7] MRG: Fixes for #11090 (#11108) * Fixes for #11090 * Update changelog --- doc/changes/latest.inc | 2 +- mne/channels/channels.py | 3 ++- mne/io/pick.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 14fcd700df6..137cef730eb 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -33,7 +33,7 @@ Enhancements - Add ``starting_affine`` keyword argument to :func:`mne.transforms.compute_volume_registration` to initialize an alignment with an affine (:gh:`11020` by `Alex Rockhill`_) - The ``trans`` parameter in :func:`mne.make_field_map` now accepts a :class:`~pathlib.Path` object, and uses standardised loading logic (:gh:`10784` by :newcontrib:`Andrew Quinn`) - Add HTML representation for `~mne.Evoked` in Jupyter Notebooks (:gh:`11075` by `Valerii Chirkov`_ and `Andrew Quinn`_) -- Add support for ``temperature`` and ``gsr`` (galvanic skin response, i.e., electrodermal activity) channel types (:gh:`11090` by `Eric Larson`_) +- Add support for ``temperature`` and ``gsr`` (galvanic skin response, i.e., electrodermal activity) channel types (:gh:`11090`, :gh:`11108` by `Eric Larson`_ and `Richard Höchenberger`_) - Allow :func:`mne.beamformer.make_dics` to take ``pick_ori='vector'`` to compute vector source estimates (:gh:`19080` by `Alex Rockhill`_) - Add ``units`` parameter to :func:`mne.io.read_raw_edf` in case units are missing from the file (:gh:`11099` by `Alex Gramfort`_) - Add ``on_missing`` functionality to all of our classes that have a ``drop_channels`` method, to control what happens when channel names are not in the object (:gh:`11077` by `Andrew Quinn`_) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 988ded901ce..7b781d4c600 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -202,7 +202,8 @@ def equalize_channels(instances, copy=True, verbose=None): FIFF.FIFF_UNIT_T_M: 'T/m', FIFF.FIFF_UNIT_MOL: 'M', FIFF.FIFF_UNIT_NONE: 'NA', - FIFF.FIFF_UNIT_CEL: 'C'} + FIFF.FIFF_UNIT_CEL: 'C', + FIFF.FIFF_UNIT_S: 'S'} def _check_set(ch, projs, ch_type): diff --git a/mne/io/pick.py b/mne/io/pick.py index 2d33cb6d7a6..6746be92472 100644 --- a/mne/io/pick.py +++ b/mne/io/pick.py @@ -96,7 +96,7 @@ def get_channel_type_constants(include_defaults=False): unit=FIFF.FIFF_UNIT_V_M2, coil_type=FIFF.FIFFV_COIL_EEG_CSD), temperature=dict(kind=FIFF.FIFFV_TEMPERATURE_CH, - unit=FIFF.FIFF_UNIT_C), + unit=FIFF.FIFF_UNIT_CEL), gsr=dict(kind=FIFF.FIFFV_GALVANIC_CH, unit=FIFF.FIFF_UNIT_S), ) From bc30e0e8c687cd91f75432df97e4f348ee8d73eb Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Mon, 29 Aug 2022 12:31:27 +0100 Subject: [PATCH 7/7] Revert "Add error message when conversion of EEG locs to [circle deploy] (#11104) --- doc/changes/latest.inc | 1 - mne/channels/montage.py | 8 -------- mne/channels/tests/test_montage.py | 11 ----------- 3 files changed, 20 deletions(-) diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 137cef730eb..2d237de1fea 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -49,7 +49,6 @@ Bugs - Fix bug in :func:`mne.viz.plot_filter` when plotting filters created using ``output='ba'`` mode with ``compensation`` turned on. (:gh:`11040` by `Marian Dovgialo`_) - Fix bug in :func:`mne.io.read_raw_bti` where EEG, EMG, and H/VEOG channels were not detected properly, and many non-ECG channels were called ECG. The logic has been improved, and any channels of unknown type are now labeled as ``misc`` (:gh:`11102` by `Eric Larson`_) - Fix bug in :func:`mne.viz.plot_topomap` when providing ``sphere="eeglab"`` (:gh:`11081` by `Mathieu Scheltienne`_) -- Applying a montage where EEG locations are not in head space (or unknown space) without fiducials will now raise an error message. (:gh:`11080` by `Marijn van Vliet`_) API changes ~~~~~~~~~~~ diff --git a/mne/channels/montage.py b/mne/channels/montage.py index 21c183ddcdc..80968a7cf0a 100644 --- a/mne/channels/montage.py +++ b/mne/channels/montage.py @@ -668,19 +668,11 @@ def transform_to_head(montage): # Get fiducial points and their coord_frame native_head_t = compute_native_head_t(montage) montage = montage.copy() # to avoid inplace modification - if native_head_t['from'] != FIFF.FIFFV_COORD_HEAD: for d in montage.dig: if d['coord_frame'] == native_head_t['from']: d['r'] = apply_trans(native_head_t, d['r']) d['coord_frame'] = FIFF.FIFFV_COORD_HEAD - elif d['kind'] == FIFF.FIFFV_POINT_EEG: - raise RuntimeError( - f'Could not transform EEG channel {d["ident"]} position ' - f'from {_verbose_frames[d["coord_frame"]]} to head ' - 'coordinates. Fiducial points are either missing or ' - 'specified in a different coordinate frame than the EEG ' - 'channel locations.') return montage diff --git a/mne/channels/tests/test_montage.py b/mne/channels/tests/test_montage.py index 703a67dfe2c..bb61857478e 100644 --- a/mne/channels/tests/test_montage.py +++ b/mne/channels/tests/test_montage.py @@ -1210,17 +1210,6 @@ def test_transform_to_head_and_compute_dev_head_t(): montage_polhemus )) - # Test errors when transforming without fiducials explicitly where points - # are tagged to be not in head or unknown coord space. - montage_without_fids = make_dig_montage( - ch_pos={"ch_1": np.array([1, 2, 3]), - "ch_2": np.array([4, 5, 6]), - "ch_3": np.array([7, 8, 9])}, - coord_frame="mri") # MRI coordinate space - with pytest.raises(RuntimeError, match='Could not transform EEG channel'): - with pytest.warns(RuntimeWarning, match='Fiducial point .* not found'): - transform_to_head(montage_without_fids) - def test_set_montage_with_mismatching_ch_names(): """Test setting a DigMontage with mismatching ch_names."""