Skip to content

Commit

Permalink
Allow not dropping bads when creating or plotting Spectrum objs (#12006)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel McCloy <dan@mccloy.info>
  • Loading branch information
3 people authored Oct 4, 2023
1 parent 63ce95d commit d79606b
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 6 deletions.
2 changes: 2 additions & 0 deletions doc/changes/names.inc
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@
.. _George O'Neill: https://georgeoneill.github.io

.. _Gonzalo Reina: https://greina.me/

.. _Guillaume Dumas: https://mila.quebec/en/person/guillaume-dumas

.. _Guillaume Favelier: https://github.com/GuillaumeFavelier
Expand Down
3 changes: 3 additions & 0 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2366,6 +2366,7 @@ def compute_psd(
picks=None,
proj=False,
remove_dc=True,
exclude=(),
*,
n_jobs=1,
verbose=None,
Expand All @@ -2382,6 +2383,7 @@ def compute_psd(
%(picks_good_data_noref)s
%(proj_psd)s
%(remove_dc)s
%(exclude_psd)s
%(n_jobs)s
%(verbose)s
%(method_kw_psd)s
Expand Down Expand Up @@ -2410,6 +2412,7 @@ def compute_psd(
tmin=tmin,
tmax=tmax,
picks=picks,
exclude=exclude,
proj=proj,
remove_dc=remove_dc,
n_jobs=n_jobs,
Expand Down
3 changes: 3 additions & 0 deletions mne/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,6 +1051,7 @@ def compute_psd(
picks=None,
proj=False,
remove_dc=True,
exclude=(),
*,
n_jobs=1,
verbose=None,
Expand All @@ -1067,6 +1068,7 @@ def compute_psd(
%(picks_good_data_noref)s
%(proj_psd)s
%(remove_dc)s
%(exclude_psd)s
%(n_jobs)s
%(verbose)s
%(method_kw_psd)s
Expand Down Expand Up @@ -1095,6 +1097,7 @@ def compute_psd(
tmin=tmin,
tmax=tmax,
picks=picks,
exclude=exclude,
proj=proj,
remove_dc=remove_dc,
reject_by_annotation=False,
Expand Down
3 changes: 3 additions & 0 deletions mne/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2136,6 +2136,7 @@ def compute_psd(
tmin=None,
tmax=None,
picks=None,
exclude=(),
proj=False,
remove_dc=True,
reject_by_annotation=True,
Expand All @@ -2153,6 +2154,7 @@ def compute_psd(
%(fmin_fmax_psd)s
%(tmin_tmax_psd)s
%(picks_good_data_noref)s
%(exclude_psd)s
%(proj_psd)s
%(remove_dc)s
%(reject_by_annotation_psd)s
Expand Down Expand Up @@ -2184,6 +2186,7 @@ def compute_psd(
tmin=tmin,
tmax=tmax,
picks=picks,
exclude=exclude,
proj=proj,
remove_dc=remove_dc,
reject_by_annotation=reject_by_annotation,
Expand Down
13 changes: 12 additions & 1 deletion mne/time_frequency/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def __init__(
tmin,
tmax,
picks,
exclude,
proj,
remove_dc,
*,
Expand Down Expand Up @@ -348,7 +349,9 @@ def __init__(

# prep times and picks
self._time_mask = _time_mask(inst.times, tmin, tmax, sfreq=self.sfreq)
self._picks = _picks_to_idx(inst.info, picks, "data", with_ref_meg=False)
self._picks = _picks_to_idx(
inst.info, picks, "data", exclude, with_ref_meg=False
)

# add the info object. bads and non-data channels were dropped by
# _picks_to_idx() so we update the info accordingly:
Expand Down Expand Up @@ -1081,6 +1084,7 @@ class Spectrum(BaseSpectrum):
%(fmin_fmax_psd)s
%(tmin_tmax_psd)s
%(picks_good_data_noref)s
%(exclude_psd)s
%(proj_psd)s
%(remove_dc)s
%(reject_by_annotation_psd)s
Expand Down Expand Up @@ -1122,6 +1126,7 @@ def __init__(
tmin,
tmax,
picks,
exclude,
proj,
remove_dc,
reject_by_annotation,
Expand All @@ -1145,6 +1150,7 @@ def __init__(
tmin,
tmax,
picks,
exclude,
proj,
remove_dc,
n_jobs=n_jobs,
Expand Down Expand Up @@ -1290,6 +1296,7 @@ class EpochsSpectrum(BaseSpectrum, GetEpochsMixin):
%(fmin_fmax_psd)s
%(tmin_tmax_psd)s
%(picks_good_data_noref)s
%(exclude_psd)s
%(proj_psd)s
%(remove_dc)s
%(n_jobs)s
Expand Down Expand Up @@ -1327,6 +1334,7 @@ def __init__(
tmin,
tmax,
picks,
exclude,
proj,
remove_dc,
*,
Expand All @@ -1347,6 +1355,7 @@ def __init__(
tmin,
tmax,
picks,
exclude,
proj,
remove_dc,
n_jobs=n_jobs,
Expand Down Expand Up @@ -1459,6 +1468,7 @@ def average(self, method="mean"):
tmin=None,
tmax=None,
picks=None,
exclude=(),
proj=None,
remove_dc=None,
reject_by_annotation=None,
Expand Down Expand Up @@ -1561,6 +1571,7 @@ def read_spectrum(fname):
tmin=None,
tmax=None,
picks=None,
exclude=(),
proj=None,
remove_dc=None,
reject_by_annotation=None,
Expand Down
12 changes: 11 additions & 1 deletion mne/time_frequency/tests/test_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,16 @@ def test_spectrum_reject_by_annot(raw):
assert spect_no_annot != spect_reject_annot


def test_spectrum_bads_exclude(raw):
"""Test bads are not removed unless exclude="bads"."""
raw.pick("mag") # get rid of IAS channel
spect_no_excld = raw.compute_psd()
spect_with_excld = raw.compute_psd(exclude="bads")
assert raw.info["bads"] == spect_no_excld.info["bads"]
assert spect_with_excld.info["bads"] == []
assert set(raw.ch_names) - set(spect_with_excld.ch_names) == set(raw.info["bads"])


def test_spectrum_getitem_raw(raw_spectrum):
"""Test Spectrum.__getitem__ for Raw-derived spectra."""
want = raw_spectrum.get_data(slice(1, 3), fmax=7)
Expand Down Expand Up @@ -280,7 +290,7 @@ def test_spectrum_to_data_frame(inst, request, evoked):
extra_dim = () if is_epochs else (1,)
extra_cols = ["freq", "condition", "epoch"] if is_epochs else ["freq"]
# compute PSD
spectrum = inst if is_already_psd else inst.compute_psd()
spectrum = inst if is_already_psd else inst.compute_psd(exclude="bads")
n_epo, n_chan, n_freq = extra_dim + spectrum.get_data().shape
# test wide format
df_wide = spectrum.to_data_frame()
Expand Down
11 changes: 7 additions & 4 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1356,12 +1356,15 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
_exclude_spectrum = """\
exclude : list of str | 'bads'
Channel names to exclude{}. If ``'bads'``, channels
in ``spectrum.info['bads']`` are excluded; pass an empty list or tuple to
plot all channels (including "bad" channels, if any).
in ``{}info['bads']`` are excluded; pass an empty list to
include all channels (including "bad" channels, if any).
"""

docdict["exclude_spectrum_get_data"] = _exclude_spectrum.format("")
docdict["exclude_spectrum_plot"] = _exclude_spectrum.format(" from being drawn")
docdict["exclude_psd"] = _exclude_spectrum.format("", "")
docdict["exclude_spectrum_get_data"] = _exclude_spectrum.format("", "spectrum.")
docdict["exclude_spectrum_plot"] = _exclude_spectrum.format(
" from being drawn", "spectrum."
)

docdict[
"export_edf_note"
Expand Down

0 comments on commit d79606b

Please sign in to comment.