From 9489b088f35272a4c52acfcd76b1c3c6a3773413 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Thu, 12 Sep 2024 16:30:25 -0400 Subject: [PATCH] MAINT: Overwrite when possible --- mnefun/_ssp.py | 21 ++++++++------------- mnefun/_utils.py | 6 ++++++ 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/mnefun/_ssp.py b/mnefun/_ssp.py index 0ec4cca..16f465b 100644 --- a/mnefun/_ssp.py +++ b/mnefun/_ssp.py @@ -21,7 +21,8 @@ from ._epoching import _raise_bad_epochs from ._paths import get_raw_fnames, get_bad_fname from ._utils import (get_args, _fix_raw_eog_cals, _handle_dict, _safe_remove, - _get_baseline, _restrict_reject_flat, _get_epo_kwargs) + _get_baseline, _restrict_reject_flat, _get_epo_kwargs, + _overwrite) def _get_fir_kwargs(fir_design): @@ -167,7 +168,7 @@ def _compute_erm_proj(p, subj, projs, kind, bad_file, remove_existing=False, # When doing eSSS it's a bit weird to put this in pca_dir but why not pca_dir = _get_pca_dir(p, subj) cont_proj = op.join(pca_dir, 'preproc_cont-proj.fif') - write_proj(cont_proj, pr) + _overwrite(write_proj, cont_proj, pr) return pr @@ -388,10 +389,7 @@ def do_preprocessing_combined(p, subjects, run_indices): print(' obtained %d epochs from %d events.' % (len(ecg_epochs), len(ecg_events))) if len(ecg_epochs) >= 20: - kwargs = dict() - if "overwrite" in get_args(write_events): - kwargs["overwrite"] = True - write_events(ecg_eve, ecg_epochs.events, **kwargs) + _overwrite(write_events, ecg_eve, ecg_epochs.events) ecg_epochs.save(ecg_epo, **_get_epo_kwargs()) desc_prefix = 'ECG-%s-%s' % tuple(ecg_t_lims) pr = compute_proj_wrap( @@ -399,10 +397,7 @@ def do_preprocessing_combined(p, subjects, run_indices): n_mag=proj_nums[0][1], n_eeg=proj_nums[0][2], desc_prefix=desc_prefix, **proj_kwargs) assert len(pr) == np.sum(proj_nums[0][::p_sl]) - kwargs = dict() - if "overwrite" in get_args(write_events): - kwargs["overwrite"] = True - write_proj(ecg_proj, pr, **kwargs) + _overwrite(write_proj, ecg_proj, pr) projs.extend(pr) else: _raise_bad_epochs(raw, ecg_epochs, ecg_events, 'ECG', @@ -421,7 +416,7 @@ def do_preprocessing_combined(p, subjects, run_indices): del proj_nums # save the projectors - write_proj(all_proj, projs) + _overwrite(write_proj, all_proj, projs) # # Look at raw_orig for trial DQs now, it will be quick @@ -500,7 +495,7 @@ def _compute_add_eog(p, subj, raw_orig, projs, eog_nums, kind, pca_dir, len(eog_events))) del eog_events if len(eog_epochs) >= 5: - write_events(eog_eve, eog_epochs.events) + _overwrite(write_events, eog_eve, eog_epochs.events) eog_epochs.save(eog_epo, **_get_epo_kwargs()) desc_prefix = f'{kind}-%s-%s' % tuple(eog_t_lims) pr = compute_proj_wrap( @@ -508,7 +503,7 @@ def _compute_add_eog(p, subj, raw_orig, projs, eog_nums, kind, pca_dir, n_mag=eog_nums[1], n_eeg=eog_nums[2], desc_prefix=desc_prefix, **proj_kwargs) assert len(pr) == np.sum(eog_nums[::p_sl]) - write_proj(eog_proj, pr) + _overwrite(write_proj, eog_proj, pr) projs.extend(pr) else: warnings.warn('Only %d usable EOG events!' % len(eog_epochs)) diff --git a/mnefun/_utils.py b/mnefun/_utils.py index 8222f6e..053709a 100644 --- a/mnefun/_utils.py +++ b/mnefun/_utils.py @@ -463,3 +463,9 @@ def convert_ANTS_surrogate(subject, trans, subjects_dir): def get_args(obj): """Wrapper.""" return inspect.signature(obj).parameters + + +def _overwrite(func, *args, **kwargs): + if "overwrite" in get_args(func): + kwargs["overwrite"] = True + return func(*args, **kwargs)