Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT: Overwrite when possible #369

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/funloc/analysis_fun.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@
params.subject_indices = [0, 1]

# Set what processing steps will execute
default = False
default = False # except for first and last steps which have other defaults
mnefun.do_processing(
params,
fetch_raw=default, # Fetch raw recording files from acquisition machine
fetch_raw=False, # Fetch raw recording files from acquisition machine
do_score=default, # Do scoring to slice data into trials

# Before running SSS, make SUBJ/raw_fif/SUBJ_prebad.txt file with
Expand All @@ -56,9 +56,9 @@
write_epochs=default, # Write epochs to disk
gen_covs=default, # Generate covariances

# Make SUBJ/trans/SUBJ-trans.fif using mne_analyze; needed for fwd calc.
# Make SUBJ/trans/SUBJ-trans.fif using mne coreg; needed for fwd calc.
gen_fwd=default, # Generate forward solutions (and source space)
gen_inv=default, # Generate inverses
gen_report=default, # Write mne report html of results to disk
print_status=default, # Print completeness status update
print_status=True, # Print completeness status update
)
2 changes: 1 addition & 1 deletion examples/funloc/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def score(p, subjects):
events[ii, 2] = _expyfun_dict[events[ii, 2]]
fname_out = op.join(out_dir,
'ALL_' + (run_name % subj) + '-eve.lst')
mne.write_events(fname_out, events)
mne.write_events(fname_out, events, overwrite=True)

# get subject performance
devs = (events[:, 2] >= 20)
Expand Down
6 changes: 3 additions & 3 deletions mnefun/_cov.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ._epoching import _concat_resamp_raws
from ._paths import get_epochs_evokeds_fnames, get_raw_fnames, safe_inserter
from ._scoring import _read_events
from ._utils import (get_args, _get_baseline, _restrict_reject_flat,
from ._utils import (get_args, _get_baseline, _restrict_reject_flat, _overwrite,
_handle_dict, _handle_decim, _check_reject_annot_regex)


Expand Down Expand Up @@ -119,7 +119,7 @@ def gen_covariances(p, subjects, run_indices, decim):
del use_flat['eeg']
cov = compute_raw_covariance(raw, reject=use_reject, flat=use_flat,
method=p.cov_method, **kwargs_erm)
write_cov(empty_cov_name, cov)
_overwrite(write_cov, empty_cov_name, cov)

# Make evoked covariances
for ii, (inv_name, inv_run) in enumerate(zip(p.inv_names, p.inv_runs)):
Expand Down Expand Up @@ -187,5 +187,5 @@ def gen_covariances(p, subjects, run_indices, decim):
epochs2.copy().crop(*baseline).plot()
raise RuntimeError('Error computing rank')

write_cov(cov_name, cov)
_overwrite(write_cov, cov_name, cov)
print()
8 changes: 8 additions & 0 deletions mnefun/_fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Distributed under the (new) BSD License. See LICENSE.txt for more info.

from os import path as op
import datetime
import glob
import numpy as np
import re
Expand Down Expand Up @@ -114,6 +115,13 @@ def fix_eeg_channels(raw_files, anon=None, verbose=True):
if need_reorder:
raw._data[picks, :] = raw._data[picks, :][order]
if need_anon and raw.info['subject_info'] is not None:
anon = anon.copy()
if (
isinstance(raw.info["subject_info"].get("birthday"), datetime.date)
and isinstance(anon.get("birthday"), tuple) # noqa
):
anon["birthday"] = datetime.date(*anon["birthday"])
anon['birthday'] = raw.info["subject_info"]["birthday"]
raw.info['subject_info'].update(anon)
raw.info['description'] = write_key + anon_key
if isinstance(raw_file, str):
Expand Down
3 changes: 2 additions & 1 deletion mnefun/_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ._cov import _compute_rank
from ._paths import (get_epochs_evokeds_fnames, safe_inserter,
get_cov_fwd_inv_fnames)
from ._utils import _overwrite

try:
from mne import spatial_src_adjacency
Expand Down Expand Up @@ -121,7 +122,7 @@ def gen_inverses(p, subjects, run_indices):
inv = make_inverse_operator(
epochs.info, fwd_restricted, cov, rank=rank,
**kwargs)
write_inverse_operator(inv_name, inv)
_overwrite(write_inverse_operator, inv_name, inv)
if p.disp_files:
print()

Expand Down
71 changes: 31 additions & 40 deletions mnefun/_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def report_context():
with plt.style.context(style):
yield
except Exception:
plt.close("all")
matplotlib.use(old_backend, force=True)
plt.interactive(is_interactive)
raise
Expand Down Expand Up @@ -129,7 +130,6 @@ def _report_good_hpi(report, fnames, p=None, subj=None, img_form='webp'):
break
fig = plot_good_coils(fit_data, show=False)
fig.set_size_inches(10, 2)
fig.tight_layout()
figs.append(fig)
captions.append('%s: %s' % (section, op.basename(fname)))
_add_figs_to_section(
Expand All @@ -151,8 +151,6 @@ def _report_chpi_snr(report, fnames, p=None, subj=None, img_form='webp'):
fig = plot_chpi_snr_raw(raw, t_window, show=False,
verbose=False)
fig.set_size_inches(10, 5)
fig.subplots_adjust(0.1, 0.1, 0.8, 0.95,
wspace=0, hspace=0.5)
figs.append(fig)
captions.append('%s: %s' % (section, op.basename(fname)))
_add_figs_to_section(
Expand All @@ -162,6 +160,7 @@ def _report_chpi_snr(report, fnames, p=None, subj=None, img_form='webp'):

def _report_head_movement(report, fnames, p=None, subj=None, run_indices=None,
img_form='webp'):
import matplotlib.pyplot as plt
section = 'Head movement'
print((' %s ... ' % section).ljust(LJUST), end='')
t0 = time.time()
Expand All @@ -171,8 +170,9 @@ def _report_head_movement(report, fnames, p=None, subj=None, run_indices=None,
fname, raw = _check_fname_raw(fname, p, subj)
_, pos, _ = _get_fit_data(raw, p, prefix=' ')
trans_to = _load_trans_to(p, subj, run_indices, raw)
axes = plt.subplots(3, 2, sharex=True, layout="constrained")[1]
fig = plot_head_positions(pos=pos, destination=trans_to,
info=raw.info, show=False)
info=raw.info, show=False, axes=axes)
for ax in fig.axes[::2]:
"""
# tighten to the sensor limits
Expand All @@ -198,7 +198,6 @@ def _report_head_movement(report, fnames, p=None, subj=None, run_indices=None,
assert (mx >= coord).all()
ax.set_ylim(mn, mx)
fig.set_size_inches(10, 6)
fig.tight_layout()
figs.append(fig)
captions.append('%s: %s' % (section, op.basename(fname)))
del trans_to
Expand All @@ -220,7 +219,6 @@ def _report_events(report, fnames, p=None, subj=None, img_form='webp'):
if len(events) > 0:
fig = plot_events(events, raw.info['sfreq'], raw.first_samp)
fig.set_size_inches(10, 4)
fig.subplots_adjust(0.1, 0.1, 0.9, 0.99, wspace=0, hspace=0)
figs.append(fig)
captions.append('%s: %s' % (section, op.basename(fname)))
if len(figs):
Expand Down Expand Up @@ -254,7 +252,7 @@ def _report_raw_segments(report, raw, lowpass=None, img_form='webp'):
np.ones_like(new_events)]).T
with mne.utils.use_log_level('error'):
fig = raw_plot.plot(group_by='selection', butterfly=True,
events=new_events, lowpass=lowpass)
events=new_events, lowpass=lowpass, show=False)
fig.axes[0].lines[-1].set_zorder(10) # events
fig.axes[0].set(xticks=np.arange(0, len(times)) + 0.5)
xticklabels = ['%0.1f' % t for t in times]
Expand All @@ -265,15 +263,17 @@ def _report_raw_segments(report, raw, lowpass=None, img_form='webp'):
fig.delaxes(fig.axes[-1])
fig.set(figheight=(fig.axes[0].get_yticks() != 0).sum(),
figwidth=12)
fig.subplots_adjust(0.025, 0.0, 1, 1, 0, 0)
_add_figs_to_section(report, fig, section, section, image_format=img_form)
print('%5.1f sec' % ((time.time() - t0),))


def _gen_psd_plot(raw, fmax, n_fft, ax):
n_fft = min(n_fft, len(raw.times))
if hasattr(raw, 'compute_psd'):
plot = raw.compute_psd(fmax=fmax, n_fft=n_fft).plot(show=False,
axes=ax)
with warnings.catch_warnings(record=True):
plot = raw.compute_psd(fmax=fmax, n_fft=n_fft).plot(
show=False, axes=ax,
)
else:
plot = raw.plot_psd(fmax=fmax, n_fft=n_fft, show=False, ax=ax)
return plot
Expand All @@ -294,7 +294,7 @@ def _report_raw_psd(report, raw, raw_pca=None, raw_erm=None, raw_erm_pca=None,
n_fft = min(8192, len(raw.times))
fmax = raw.info['lowpass']
n_ax = sum(key in raw for key in ('mag', 'grad', 'eeg'))
_, ax = plt.subplots(n_ax, figsize=(10, 8))
_, ax = plt.subplots(n_ax, figsize=(10, 8), layout="constrained")
figs = [_gen_psd_plot(raw, fmax=fmax, n_fft=n_fft, ax=ax)]
captions = ['%s: Raw' % section]
fmax = lp_cut + 2 * lp_trans
Expand All @@ -303,7 +303,7 @@ def _report_raw_psd(report, raw, raw_pca=None, raw_erm=None, raw_erm_pca=None,
(raw_pca, f'{section}: Raw processed (zoomed)'),
(raw_erm, f'{section}: ERM (zoomed)'),
(raw_erm_pca, f'{section}: ERM processed (zoomed)')]:
_, ax = plt.subplots(n_ax, figsize=(10, 8))
_, ax = plt.subplots(n_ax, figsize=(10, 8), layout="constrained")
if this_raw is not None:
figs.append(_gen_psd_plot(this_raw, fmax=fmax, n_fft=n_fft, ax=ax))
captions.append(caption)
Expand Down Expand Up @@ -398,7 +398,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
preload = p.report_params.get('preload', False)
for si, subj in enumerate(subjects):
struc = structurals[si] if structurals is not None else None
report = Report(verbose=False)
report = Report(title=subj, verbose=False)
print(' Processing subject %s/%s (%s)'
% (si + 1, len(subjects), subj))

Expand Down Expand Up @@ -771,7 +771,8 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
n_e = len(all_evoked)
n_row = n_s * n_e + 1
figs, axes = plt.subplots(
n_row, 1, figsize=(7, 3 * n_row))
n_row, 1, figsize=(7, 3 * n_row), layout="constrained",
)
captions = [
'%s: %s["%s"] (N=%s)'
% (section, analysis, all_evoked[0].comment,
Expand All @@ -782,7 +783,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
sl = slice(ei, n_e * n_s, n_e)
these_axes = list(axes[sl]) + [axes[-1]]
evo.plot_white(
noise_cov, verbose='error', axes=these_axes)
noise_cov, verbose='error', axes=these_axes, show=False)
for ax in these_axes[:-1]:
n_text = 'N=%d' % (evo.nave,)
if ei != 0:
Expand Down Expand Up @@ -826,7 +827,6 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
axes[-1].set_title(
f'{axes[-1].get_title()} (halves {SQ2STR})')
axes[-1]
figs.tight_layout()
figs = [figs]
_add_figs_to_section(
report, figs, captions, section=section,
Expand Down Expand Up @@ -864,7 +864,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
% op.basename(fname_evoked), end='')
else:
inv = mne.minimum_norm.read_inverse_operator(fname_inv)
figs, ax = plt.subplots(1, figsize=(7, 5))
figs, ax = plt.subplots(1, figsize=(7, 5), layout="constrained")
all_evoked = _get_std_even_odd(fname_evoked, name)
for ei, evoked in enumerate(all_evoked):
if ei != 0:
Expand All @@ -873,7 +873,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
try:
evoked.nave = max(orig, 1)
plot_snr_estimate(
evoked, inv, axes=ax, verbose='error')
evoked, inv, show=False, axes=ax, verbose='error')
finally:
evoked.nave = orig
if len(all_evoked) > 1:
Expand All @@ -899,7 +899,6 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
% (section, analysis, name,
'/'.join(str(e.nave)
for e in all_evoked)))
figs.tight_layout()
_add_figs_to_section(
report, figs, captions, section=section,
image_format=img_form)
Expand Down Expand Up @@ -1003,9 +1002,6 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
times, **kwargs)
assert isinstance(fig, plt.Figure)
fig.axes[0].set(ylim=(-max_, max_))
t = fig.axes[-1].texts[0]
t.set_text(
f'{t.get_text()}; {n_text})')
else:
fig = plt.figure()
all_figs += [fig]
Expand Down Expand Up @@ -1161,8 +1157,6 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
% (section, analysis, name, extra,
this_evoked.nave,))
captions = ['%2.3f sec' % t for t in times]
print(f'add {repr(title)}')
print(repr(captions))
_add_slider_to_section(
report, figs, captions=captions, section=section,
title=title, image_format=img_form)
Expand Down Expand Up @@ -1235,7 +1229,7 @@ def _proj_fig(fname, info, proj_nums, proj_meg, kind, use_ch, duration):
ch_names = [info['ch_names'][pick]
for pick in mne.pick_types(info, meg=meg, eeg=eeg)]
# Some of these will be missing because of prebads
idx = np.where([np.in1d(ch_names, proj['data']['col_names']).all()
idx = np.where([np.isin(ch_names, proj['data']['col_names']).all()
for proj in projs])[0]
if len(idx) != count:
raise RuntimeError('Expected %d %s projector%s for channel type '
Expand All @@ -1253,17 +1247,18 @@ def _proj_fig(fname, info, proj_nums, proj_meg, kind, use_ch, duration):
for name in ch_names]
proj['data']['data'] = proj['data']['data'][:, sub_idx]
proj['data']['col_names'] = ch_names
topo_axes = [plt.subplot2grid(shape, (ri, ci + cs_trace))
fig = plt.figure(layout="constrained")
topo_axes = [plt.subplot2grid(shape, (ri, ci + cs_trace), fig=fig)
for ci in range(count)]
# topomaps
with warnings.catch_warnings(record=True):
plot_projs_topomap(these_projs, info=info, show=False,
axes=topo_axes)
plot_projs_topomap(
these_projs, info=info, show=False, axes=topo_axes,
)
plt.setp(topo_axes, title='', xlabel='')
unit = mne.defaults.DEFAULTS['units'][ch_type]
if cs_trace:
ax = plt.subplot2grid(shape, (ri, n_col - cs_trace - 1),
colspan=cs_trace)
colspan=cs_trace, fig=fig)
this_evoked = evoked.copy().pick_channels(ch_names)
p = np.concatenate([p['data']['data'] for p in these_projs])
assert p.shape == (len(these_projs), len(this_evoked.data))
Expand All @@ -1274,8 +1269,7 @@ def _proj_fig(fname, info, proj_nums, proj_meg, kind, use_ch, duration):
ch_traces = evoked.copy().pick_channels(use_ch).data
ch_traces -= np.mean(ch_traces, axis=1, keepdims=True)
ch_traces /= np.abs(ch_traces).max()
with warnings.catch_warnings(record=True): # tight_layout
this_evoked.plot(picks='all', axes=[ax])
this_evoked.plot(picks='all', axes=[ax], show=False)
for line in ax.lines:
line.set(lw=0.5, zorder=3)
for t in list(ax.texts):
Expand Down Expand Up @@ -1307,16 +1301,14 @@ def _proj_fig(fname, info, proj_nums, proj_meg, kind, use_ch, duration):
need_legend = False
last_ax[1] = ax
# Before and after traces
ax = plt.subplot2grid(shape, (ri, 0), colspan=cs_trace)
with warnings.catch_warnings(record=True): # tight_layout
this_evoked.plot(
picks='all', axes=[ax])
ax = plt.subplot2grid(shape, (ri, 0), colspan=cs_trace, fig=fig)
this_evoked.plot(picks='all', axes=[ax], show=False)
for line in ax.lines:
line.set(lw=0.5, zorder=3)
loff = len(ax.lines)
with warnings.catch_warnings(record=True): # tight_layout
this_evoked.copy().add_proj(projs).apply_proj().plot(
picks='all', axes=[ax])
e = this_evoked.copy().add_proj(projs)
e.info.normalize_proj()
e.apply_proj().plot(picks='all', axes=[ax], show=False)
for line in ax.lines[loff:]:
line.set(lw=0.5, zorder=4, color='g')
for t in list(ax.texts):
Expand All @@ -1328,7 +1320,6 @@ def _proj_fig(fname, info, proj_nums, proj_meg, kind, use_ch, duration):
if cs_trace:
for ax in last_ax:
ax.set(xlabel='Time (sec)')
fig.subplots_adjust(0.1, 0.15, 1.0, 0.9, wspace=0.25, hspace=0.2)
assert used.all() and (used <= 2).all()
return fig

Expand Down
Loading