Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "autoreject_local" as new option for preprocessing/ptp_reject #807

Merged
merged 21 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/settings/preprocessing/artifacts.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ tags:
- reject
- reject_tmin
- reject_tmax
- autoreject_n_interpolate
4 changes: 4 additions & 0 deletions docs/source/v1.5.md.inc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ All users are encouraged to update.
- Input validation has been improved by leveraging [pydantic](https://docs.pydantic.dev) (#779 by @larsoner)
- Reduced logging when reports are created and saved (#799 by @hoechenberger)
- Added [`"picard-extended_infomax"`][mne_bids_pipeline._config.ica_algorithm] ICA algorithm to perform "extended Infomax"-like ICA decomposition using Picard (#801 by @hoechenberger)
- Added support for using "local" [`autoreject`](https://autoreject.github.io) to find (and repair) bad channels on a
per-epochs basis; this can be enabled by setting [`reject`][mne_bids_pipeline._config.reject] to `"autoreject_local"`.
The behavior can further be controlled via the new setting
[`autoreject_n_interpolate`][mne_bids_pipeline._config.autoreject_n_interpolate]. (#807 by @hoechenberger)

[//]: # (### :warning: Behavior changes)

Expand Down
72 changes: 44 additions & 28 deletions mne_bids_pipeline/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1355,54 +1355,51 @@
# Rejection based on peak-to-peak amplitude
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

reject: Optional[Union[Dict[str, float], Literal["autoreject_global"]]] = None
reject: Optional[
Union[Dict[str, float], Literal["autoreject_global", "autoreject_local"]]
] = None
"""
Peak-to-peak amplitude limits to mark epochs as bad. This allows you to remove
epochs with strong transient artifacts.

If `None` (default), do not apply artifact rejection. If a dictionary,
manually specify rejection thresholds (see examples). If
`'autoreject_global'`, use [`autoreject`](https://autoreject.github.io) to find
suitable "global" rejection thresholds for each channel type, i.e. `autoreject`
will generate a dictionary with (hopefully!) optimal thresholds for each
channel type.
!!! info
The rejection is performed **after** SSP or ICA, if any of those methods
is used. To reject epochs **before** fitting ICA, see the
[`ica_reject`][mne_bids_pipeline._config.ica_reject] setting.

If `None` (default), do not apply artifact rejection.

If a dictionary, manually specify rejection thresholds (see examples).
The thresholds provided here must be at least as stringent as those in
[`ica_reject`][mne_bids_pipeline._config.ica_reject] if using ICA. In case of
`'autoreject_global'`, thresholds for any channel that do not meet this
requirement will be automatically replaced with those used in `ica_reject`.

!!! info
The rejection is performed **after** SSP or ICA, if any of those methods
is used. To reject epochs **before** fitting ICA, see the
[`ica_reject`][mne_bids_pipeline._config.ica_reject] setting.

If `None` (default), do not apply automated rejection. If a dictionary,
manually specify rejection thresholds (see examples). If `'auto'`, use
[`autoreject`](https://autoreject.github.io) to find suitable "global"
rejection thresholds for each channel type, i.e. `autoreject` will generate
a dictionary with (hopefully!) optimal thresholds for each channel type. Note
that using `autoreject` can be a time-consuming process.
If `"autoreject_global"`, use [`autoreject`](https://autoreject.github.io) to find
suitable "global" rejection thresholds for each channel type, i.e., `autoreject`
will generate a dictionary with (hopefully!) optimal thresholds for each
channel type.

!!! info
`autoreject` basically offers two modes of operation: "global" and
"local". In "global" mode, it will try to estimate one rejection
threshold **per channel type.** In "local" mode, it will generate
thresholds **for each individual channel.** Currently, the BIDS Pipeline
only supports the "global" mode.
If `"autoreject_local"`, use "local" `autoreject` to detect (and potentially repair) bad
channels in each epoch. Use [`autoreject_n_interpolate`][mne_bids_pipeline._config.autoreject_n_interpolate]
to control how many channels are allowed to be bad before an epoch gets dropped.

???+ example "Example"
```python
reject = {'grad': 4000e-13, 'mag': 4e-12, 'eog': 150e-6}
reject = {'eeg': 100e-6, 'eog': 250e-6}
reject = {"grad": 4000e-13, 'mag': 4e-12, 'eog': 150e-6}
reject = {"eeg": 100e-6, "eog": 250e-6}
reject = None # no rejection based on PTP amplitude
reject = "autoreject_global" # find global (per channel type) PTP thresholds
reject = "autoreject_local" # find local (per channel) thresholds and repair epochs
```
"""

reject_tmin: Optional[float] = None
"""
Start of the time window used to reject epochs. If `None`, the window will
start with the first time point.
start with the first time point. Has no effect if
[`reject`][mne_bids_pipeline._config.reject] has been set to `"autoreject_local"`.

???+ example "Example"
```python
reject_tmin = -0.1 # 100 ms before event onset.
Expand All @@ -1412,13 +1409,32 @@
reject_tmax: Optional[float] = None
"""
End of the time window used to reject epochs. If `None`, the window will end
with the last time point.
with the last time point. Has no effect if
[`reject`][mne_bids_pipeline._config.reject] has been set to `"autoreject_local"`.

???+ example "Example"
```python
reject_tmax = 0.3 # 300 ms after event onset.
```
"""

autoreject_n_interpolate: FloatArrayLike = [4, 8, 16]
"""
The maximum number of bad channels in an epoch that `autoreject` local will try to
interpolate. The optimal number among this list will be estimated using a
cross-validation procedure; this means that the more elements are provided here, the
longer the `autoreject` run will take. If the number of bad channels in an epoch
exceeds this value, the channels won't be interpolated and the epoch will be dropped.

!!! info
This setting only takes effect if [`reject`][mne_bids_pipeline._config.reject] has
been set to `"autoreject_local"`.

!!! info
Channels marked as globally bad in the BIDS dataset (in `*_channels.tsv)`) will not
be considered (i.e., will remain marked as bad and not analyzed by autoreject).
"""

###############################################################################
# DECODING
# --------
Expand Down
3 changes: 2 additions & 1 deletion mne_bids_pipeline/_config_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def _check_config(config: SimpleNamespace, config_path: Optional[PathLike]) -> N
if (
ica_reject is not None
and reject is not None
and reject != "autoreject_global"
and reject not in ["autoreject_global", "autoreject_local"]
):
for ch_type in reject:
if ch_type in ica_reject and reject[ch_type] > ica_reject[ch_type]:
Expand Down Expand Up @@ -354,6 +354,7 @@ def _default_factory(key, val):
{"custom": (8, 24.0, 40)}, # decoding_csp_freqs
{"suffix": "ave"}, # source_info_path_update
["evoked"], # inverse_targets
[4, 8, 16], # autoreject_n_interpolate
]
for typ in (dict, list):
if isinstance(val, typ):
Expand Down
119 changes: 84 additions & 35 deletions mne_bids_pipeline/steps/preprocessing/_08_ptp_reject.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from types import SimpleNamespace
from typing import Optional

import numpy as np
import autoreject

import mne
from mne_bids import BIDSPath

Expand Down Expand Up @@ -79,46 +82,72 @@ def drop_ptp(

# Get rejection parameters and drop bad epochs
epochs = mne.read_epochs(in_files.pop("epochs"), preload=True)
reject = _get_reject(
subject=subject,
session=session,
reject=cfg.reject,
ch_types=cfg.ch_types,
param="reject",
epochs=epochs,
)
if cfg.spatial_filter == "ica":
ica_reject = _get_reject(

if cfg.reject == "autoreject_local":
msg = "Using autoreject to find and repair bad epochs"
logger.info(**gen_log_kwargs(message=msg))

ar = autoreject.AutoReject(
n_interpolate=np.array(cfg.autoreject_n_interpolate),
random_state=cfg.random_state,
n_jobs=exec_params.n_jobs,
verbose=False,
)
n_epochs_before_reject = len(epochs)
epochs, reject_log = ar.fit_transform(epochs, return_log=True)
n_epochs_after_reject = len(epochs)
assert (
n_epochs_before_reject - n_epochs_after_reject
== reject_log.bad_epochs.sum()
)

msg = (
f"autoreject marked {reject_log.bad_epochs.sum()} epochs as bad "
f"(cross-validated n_interpolate limit: {ar.n_interpolate_})"
)
logger.info(**gen_log_kwargs(message=msg))
else:
reject = _get_reject(
subject=subject,
session=session,
reject=cfg.ica_reject,
reject=cfg.reject,
ch_types=cfg.ch_types,
param="ica_reject",
param="reject",
epochs=epochs,
)
else:
ica_reject = None

if ica_reject is not None:
for ch_type, threshold in ica_reject.items():
if ch_type in reject and threshold < reject[ch_type]:
# This can only ever happen in case of
# reject = 'autoreject_global'
msg = (
f"Adjusting PTP rejection threshold proposed by "
f"autoreject, as it is greater than ica_reject: "
f"{ch_type}: {reject[ch_type]} -> {threshold}"
)
logger.info(**gen_log_kwargs(message=msg))
reject[ch_type] = threshold

msg = f"Using PTP rejection thresholds: {reject}"
logger.info(**gen_log_kwargs(message=msg))

n_epochs_before_reject = len(epochs)
epochs.reject_tmin = cfg.reject_tmin
epochs.reject_tmax = cfg.reject_tmax
epochs.drop_bad(reject=reject)
n_epochs_after_reject = len(epochs)
if cfg.spatial_filter == "ica":
ica_reject = _get_reject(
subject=subject,
session=session,
reject=cfg.ica_reject,
ch_types=cfg.ch_types,
param="ica_reject",
)
else:
ica_reject = None

if ica_reject is not None:
for ch_type, threshold in ica_reject.items():
if ch_type in reject and threshold < reject[ch_type]:
# This can only ever happen in case of
# reject = 'autoreject_global'
msg = (
f"Adjusting PTP rejection threshold proposed by "
f"autoreject, as it is greater than ica_reject: "
f"{ch_type}: {reject[ch_type]} -> {threshold}"
)
logger.info(**gen_log_kwargs(message=msg))
reject[ch_type] = threshold

msg = f"Using PTP rejection thresholds: {reject}"
logger.info(**gen_log_kwargs(message=msg))

n_epochs_before_reject = len(epochs)
epochs.reject_tmin = cfg.reject_tmin
epochs.reject_tmax = cfg.reject_tmax
epochs.drop_bad(reject=reject)
n_epochs_after_reject = len(epochs)

if 0 < n_epochs_after_reject < 0.5 * n_epochs_before_reject:
msg = (
Expand Down Expand Up @@ -155,6 +184,24 @@ def drop_ptp(
with _open_report(
cfg=cfg, exec_params=exec_params, subject=subject, session=session
) as report:
if cfg.reject == "autoreject_local":
caption = (
f"Autoreject was run to produce cleaner epochs. "
f"{reject_log.bad_epochs.sum()} epochs were rejected because more than "
f"{ar.n_interpolate_} channels were bad (cross-validated n_interpolate "
f"limit; excluding globally bad and non-data channels, shown in white)."
)
report.add_figure(
fig=reject_log.plot(
orientation="horizontal", aspect="auto", show=False
),
title="Epochs: Autoreject cleaning",
caption=caption,
tags=("epochs", "autoreject"),
replace=True,
)
del caption

report.add_epochs(
epochs=epochs,
title="Epochs: after cleaning",
Expand All @@ -176,6 +223,8 @@ def get_config(
spatial_filter=config.spatial_filter,
ica_reject=config.ica_reject,
reject=config.reject,
autoreject_n_interpolate=config.autoreject_n_interpolate,
random_state=config.random_state,
ch_types=config.ch_types,
_epochs_split_size=config._epochs_split_size,
**_bids_kwargs(config=config),
Expand Down
18 changes: 11 additions & 7 deletions mne_bids_pipeline/tests/configs/config_ERP_CORE.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,18 @@
t_break_annot_start_after_previous_event = 3.0
t_break_annot_stop_before_next_event = 1.5

ica_reject = dict(eeg=350e-6, eog=500e-6)
reject = "autoreject_global"
if task == "N400": # test autoreject local without ICA
spatial_filter = None
reject = "autoreject_local"
autoreject_n_interpolate = [2, 4]
else:
ica_reject = dict(eeg=350e-6, eog=500e-6)
reject = "autoreject_global"

spatial_filter = "ica"
ica_max_iterations = 1000
ica_eog_threshold = 2
ica_decim = 2 # speed up ICA fitting
spatial_filter = "ica"
ica_max_iterations = 1000
ica_eog_threshold = 2
ica_decim = 2 # speed up ICA fitting

run_source_estimation = False

Expand All @@ -104,7 +109,6 @@
}

eeg_reference = ["P9", "P10"]
ica_n_components = 30 - len(eeg_reference)
epochs_tmin = -0.2
epochs_tmax = 0.8
epochs_metadata_tmin = 0
Expand Down