Skip to content

Commit

Permalink
wip: first sketch of spectrum class [ci skip]
Browse files Browse the repository at this point in the history
implement __repr__; add placeholder _repr_html_

add draft of _repr_html_

default to multitaper for Evokeds

make raw.plot_psd() use the new code path

unify viz.plot_raw_psd code path too

support unaggregated multitaper

add picks param to spectrum.plot()

fix(ish) the units() method

allow average=False as synonym for None

handle unaggregated estimates in combo with epochs

fix CI plotting

implement get_data method [ci skip]
  • Loading branch information
drammock committed Jun 11, 2022
1 parent bd32944 commit 88a7ef9
Show file tree
Hide file tree
Showing 17 changed files with 891 additions and 272 deletions.
13 changes: 13 additions & 0 deletions doc/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -2284,3 +2284,16 @@ @article{LuckGaspelin2017
issn = {1469-8986},
doi = {10.1111/psyp.12639},
}

@article{Welch1967,
title = {The Use of Fast {{Fourier}} Transform for the Estimation of Power Spectra: {{A}} Method Based on Time Averaging over Short, Modified Periodograms},
shorttitle = {The Use of Fast {{Fourier}} Transform for the Estimation of Power Spectra},
author = {Welch, Peter D.},
year = {1967},
journal = {IEEE Transactions on Audio and Electroacoustics},
volume = {15},
number = {2},
pages = {70--73},
issn = {0018-9278},
doi = {10.1109/TAU.1967.1161901},
}
1 change: 1 addition & 0 deletions doc/time_frequency.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Time-Frequency
AverageTFR
EpochsTFR
CrossSpectralDensity
Spectrum

Functions that operate on mne-python objects:

Expand Down
12 changes: 8 additions & 4 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def set_meas_date(self, meas_date):


class UpdateChannelsMixin(object):
"""Mixin class for Raw, Evoked, Epochs, AverageTFR."""
"""Mixin class for Raw, Evoked, Epochs, Spectrum, AverageTFR."""

@verbose
def pick_types(self, meg=False, eeg=False, stim=False, eog=False,
Expand Down Expand Up @@ -839,7 +839,7 @@ def drop_channels(self, ch_names):
def _pick_drop_channels(self, idx, *, verbose=None):
# avoid circular imports
from ..io import BaseRaw
from ..time_frequency import AverageTFR, EpochsTFR
from ..time_frequency import AverageTFR, EpochsTFR, Spectrum

msg = 'adding, dropping, or reordering channels'
if isinstance(self, BaseRaw):
Expand All @@ -864,8 +864,12 @@ def _pick_drop_channels(self, idx, *, verbose=None):
if mat is not None:
setattr(self, key, mat[idx][:, idx])

# All others (Evoked, Epochs, Raw) have chs axis=-2
axis = -3 if isinstance(self, (AverageTFR, EpochsTFR)) else -2
if isinstance(self, Spectrum):
axis = self._dims.index('channel')
elif isinstance(self, (AverageTFR, EpochsTFR)):
axis = -3
else: # All others (Evoked, Epochs, Raw) have chs axis=-2
axis = -2
if hasattr(self, '_data'): # skip non-preloaded Raw
self._data = self._data.take(idx, axis=axis)
else:
Expand Down
15 changes: 7 additions & 8 deletions mne/decoding/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,15 @@

import numpy as np

from .mixin import TransformerMixin
from .base import BaseEstimator

from .. import pick_types
from ..filter import filter_data, _triage_filter_params
from ..time_frequency.psd import psd_array_multitaper
from ..utils import fill_doc, _check_option, _validate_type, verbose
from ..io.pick import (pick_info, _pick_data_channels, _picks_by_type,
_picks_to_idx)
from ..cov import _check_scalings_user
from ..filter import _triage_filter_params, filter_data
from ..io.pick import (_pick_data_channels, _picks_by_type, _picks_to_idx,
pick_info)
from ..time_frequency import psd_array_multitaper
from ..utils import _check_option, _validate_type, fill_doc, verbose
from .base import BaseEstimator
from .mixin import TransformerMixin


class _ConstantScaler():
Expand Down
99 changes: 50 additions & 49 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,61 +11,61 @@
#
# License: BSD-3-Clause

from functools import partial
from collections import Counter
from copy import deepcopy
import json
import operator
import os.path as op
from collections import Counter
from copy import deepcopy
from functools import partial

import numpy as np

from .io.utils import _construct_bids_filename
from .io.write import (start_and_end_file, start_block, end_block,
write_int, write_float, write_float_matrix,
write_double_matrix, write_complex_float_matrix,
write_complex_double_matrix, write_id, write_string,
_get_split_size, _NEXT_FILE_BUFFER, INT32_MAX)
from .io.meas_info import (read_meas_info, write_meas_info, _merge_info,
_ensure_infos_match, ContainsMixin)
from .io.open import fiff_open, _get_next_fname
from .io.tree import dir_tree_find
from .io.tag import read_tag, read_tag_info
from .io.constants import FIFF
from .io.fiff.raw import _get_fname_rep
from .io.pick import (channel_indices_by_type, channel_type,
pick_channels, pick_info, _pick_data_channels,
_DATA_CH_TYPES_SPLIT, _picks_to_idx)
from .io.proj import setup_proj, ProjMixin
from .io.base import BaseRaw, TimeMixin, _get_ch_factors
from .annotations import (EpochAnnotationsMixin, _read_annotations_fif,
_write_annotations)
from .baseline import _check_baseline, _log_rescale, rescale
from .bem import _check_origin
from .evoked import EvokedArray, _check_decim
from .baseline import rescale, _log_rescale, _check_baseline
from .channels.channels import (UpdateChannelsMixin,
SetChannelsMixin, InterpolationMixin)
from .filter import detrend, FilterMixin, _check_fun
from .parallel import parallel_func

from .channels.channels import (InterpolationMixin, SetChannelsMixin,
UpdateChannelsMixin)
from .event import (_read_events_fif, make_fixed_length_events,
match_event_names)
from .evoked import EvokedArray, _check_decim
from .filter import FilterMixin, _check_fun, detrend
from .fixes import rng_uniform
from .viz import (plot_epochs, plot_epochs_psd, plot_epochs_psd_topomap,
plot_epochs_image, plot_topo_image_epochs, plot_drop_log)
from .utils import (_check_fname, check_fname, logger, verbose,
_time_mask, check_random_state, warn, _pl,
sizeof_fmt, SizeMixin, copy_function_doc_to_method_doc,
_check_pandas_installed,
_check_preload, GetEpochsMixin,
from .io.base import BaseRaw, TimeMixin, _get_ch_factors
from .io.constants import FIFF
from .io.fiff.raw import _get_fname_rep
from .io.meas_info import (ContainsMixin, _ensure_infos_match, _merge_info,
read_meas_info, write_meas_info)
from .io.open import _get_next_fname, fiff_open
from .io.pick import (_DATA_CH_TYPES_SPLIT, _pick_data_channels, _picks_to_idx,
channel_indices_by_type, channel_type, pick_channels,
pick_info)
from .io.proj import ProjMixin, setup_proj
from .io.tag import read_tag, read_tag_info
from .io.tree import dir_tree_find
from .io.utils import _construct_bids_filename
from .io.write import (_NEXT_FILE_BUFFER, INT32_MAX, _get_split_size,
end_block, start_and_end_file, start_block,
write_complex_double_matrix, write_complex_float_matrix,
write_double_matrix, write_float, write_float_matrix,
write_id, write_int, write_string)
from .parallel import parallel_func
from .time_frequency.spectrum import ToSpectrumMixin
from .utils import (GetEpochsMixin, ShiftTimeMixin, SizeMixin,
_build_data_frame, _check_combine, _check_event_id,
_check_fname, _check_option, _check_pandas_index_arguments,
_check_pandas_installed, _check_preload,
_check_time_format, _convert_times, _ensure_events,
_gen_events, _on_missing, _path_like, _pl,
_prepare_read_metadata, _prepare_write_metadata,
_check_event_id, _gen_events, _check_option,
_check_combine, ShiftTimeMixin, _build_data_frame,
_check_pandas_index_arguments, _convert_times,
_scale_dataframe_data, _check_time_format, object_size,
_on_missing, _validate_type, _ensure_events,
_path_like)
_scale_dataframe_data, _time_mask, _validate_type,
check_fname, check_random_state,
copy_function_doc_to_method_doc, logger, object_size,
sizeof_fmt, verbose, warn)
from .utils.docs import fill_doc
from .annotations import (_write_annotations, _read_annotations_fif,
EpochAnnotationsMixin)
from .viz import (plot_drop_log, plot_epochs, plot_epochs_image,
plot_epochs_psd, plot_epochs_psd_topomap,
plot_topo_image_epochs)


def _pack_reject_params(epochs):
Expand Down Expand Up @@ -340,7 +340,8 @@ def _handle_event_repeated(events, event_id, event_repeated, selection,
@fill_doc
class BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin, ShiftTimeMixin,
SetChannelsMixin, InterpolationMixin, FilterMixin,
TimeMixin, SizeMixin, GetEpochsMixin, EpochAnnotationsMixin):
TimeMixin, SizeMixin, GetEpochsMixin, EpochAnnotationsMixin,
ToSpectrumMixin):
"""Abstract base class for `~mne.Epochs`-type classes.
.. warning:: This class provides basic functionality and should never be
Expand Down Expand Up @@ -3692,11 +3693,11 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None,
of children in MEG: Quantification, effects on source
estimation, and compensation. NeuroImage 40:541–550, 2008.
""" # noqa: E501
from .preprocessing.maxwell import (_trans_sss_basis, _reset_meg_bads,
_check_usable, _col_norm_pinv,
_get_n_moments, _get_mf_picks_fix_mags,
_prep_mf_coils, _check_destination,
_remove_meg_projs, _get_coil_scale)
from .preprocessing.maxwell import (_check_destination, _check_usable,
_col_norm_pinv, _get_coil_scale,
_get_mf_picks_fix_mags, _get_n_moments,
_prep_mf_coils, _remove_meg_projs,
_reset_meg_bads, _trans_sss_basis)
if head_pos is None:
raise TypeError('head_pos must be provided and cannot be None')
from .chpi import head_pos_to_trans_rot_t
Expand Down
24 changes: 12 additions & 12 deletions mne/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,6 @@
from .defaults import (_INTERPOLATION_DEFAULT, _EXTRAPOLATE_DEFAULT,
_BORDER_DEFAULT)
from .filter import detrend, FilterMixin, _check_fun
from .utils import (check_fname, logger, verbose, _time_mask, warn, sizeof_fmt,
SizeMixin, copy_function_doc_to_method_doc, _validate_type,
fill_doc, _check_option, ShiftTimeMixin, _build_data_frame,
_check_pandas_installed, _check_pandas_index_arguments,
_convert_times, _scale_dataframe_data, _check_time_format,
_check_preload, _check_fname)
from .viz import (plot_evoked, plot_evoked_topomap, plot_evoked_field,
plot_evoked_image, plot_evoked_topo)
from .viz.evoked import plot_evoked_white, plot_evoked_joint
from .viz.topomap import _topomap_animation

from .io.constants import FIFF
from .io.open import fiff_open
from .io.tag import read_tag
Expand All @@ -43,6 +32,17 @@
write_id, write_float, write_complex_float_matrix)
from .io.base import TimeMixin, _check_maxshield, _get_ch_factors
from .parallel import parallel_func
from .time_frequency.spectrum import ToSpectrumMixin
from .utils import (
check_fname, logger, verbose, _time_mask, warn, sizeof_fmt, SizeMixin,
copy_function_doc_to_method_doc, _validate_type, fill_doc, _check_option,
ShiftTimeMixin, _build_data_frame, _check_pandas_installed,
_check_pandas_index_arguments, _convert_times, _scale_dataframe_data,
_check_time_format, _check_preload, _check_fname)
from .viz import (plot_evoked, plot_evoked_topomap, plot_evoked_field,
plot_evoked_image, plot_evoked_topo)
from .viz.evoked import plot_evoked_white, plot_evoked_joint
from .viz.topomap import _topomap_animation

_aspect_dict = {
'average': FIFF.FIFFV_ASPECT_AVERAGE,
Expand All @@ -63,7 +63,7 @@
@fill_doc
class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin, SetChannelsMixin,
InterpolationMixin, FilterMixin, TimeMixin, SizeMixin,
ShiftTimeMixin):
ShiftTimeMixin, ToSpectrumMixin):
"""Evoked data.
Parameters
Expand Down
28 changes: 28 additions & 0 deletions mne/html_templates/repr/spectrum.html.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
<table class="table table-hover table-striped table-sm table-responsive small">
<tr>
<th>Data type</th>
<td>{{ data_type }}</td>
</tr>
<tr>
<th>Data source</th>
<td>{{ inst_type }}</td>
</tr>
<tr>
<th>Number of channels</th>
<td>{{ spectrum.ch_names|length }}</td>
</tr>
{% if "taper" in spectrum._dims %}
<tr>
<th>Number of tapers</th>
<td>{{ spectrum._mt_weights.size }}</td>
</tr>
{% endif %}
<tr>
<th>Number of frequency bins</th>
<td>{{ spectrum.freqs|length }}</td>
</tr>
<tr>
<th>Frequency range</th>
<td>{{ '%.2f'|format(spectrum.freqs[0]) }} – {{ '%.2f'|format(spectrum.freqs[-1]) }} Hz</td>
</tr>
</table>
Loading

0 comments on commit 88a7ef9

Please sign in to comment.