Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update mnefun._reports to work with current mne-python #364

Merged
merged 11 commits into from
Apr 4, 2023
2 changes: 1 addition & 1 deletion mnefun/_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def gen_forwards(p, subjects, structurals, run_indices):


def _get_bem_src_trans(p, info, subj, struc):
subjects_dir = get_subjects_dir(p.subjects_dir, raise_error=True)
subjects_dir = str(get_subjects_dir(p.subjects_dir, raise_error=True))
assert isinstance(subjects_dir, str)
if struc is None: # spherical case
bem, src, trans = _spherical_conductor(info, subj, p.src_pos)
Expand Down
122 changes: 74 additions & 48 deletions mnefun/_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
plot_events)
from mne.viz._3d import plot_head_positions
from mne.report import Report
from mne.utils import _pl, use_log_level
from mne.utils import _pl, use_log_level, logger
from mne.cov import whiten_evoked
from mne.viz.utils import _triage_rank_sss

Expand Down Expand Up @@ -48,17 +48,18 @@ def report_context():
plt.ioff()
old_backend = matplotlib.get_backend()
matplotlib.use('Agg', force=True)
try:
with plt.style.context(style):
yield
except Exception:
matplotlib.use(old_backend, force=True)
plt.interactive(is_interactive)
raise
with mne.viz.use_browser_backend('matplotlib'):
try:
with plt.style.context(style):
yield
except Exception:
matplotlib.use(old_backend, force=True)
plt.interactive(is_interactive)
raise


# Backward compat wrappers for MNE 1.0+
def _add_figs_to_section(report, figs, captions, section, image_format='png'):
def _add_figs_to_section(report, figs, captions, section, image_format='webp'):
try:
report.add_figure
except AttributeError:
Expand All @@ -82,7 +83,7 @@ def _add_figs_to_section(report, figs, captions, section, image_format='png'):


def _add_slider_to_section(report, figs, captions, section, title,
image_format='png'):
image_format='webp'):
try:
report.add_figure
except AttributeError:
Expand Down Expand Up @@ -113,7 +114,7 @@ def _check_fname_raw(fname, p, subj):
return fname, raw


def _report_good_hpi(report, fnames, p=None, subj=None):
def _report_good_hpi(report, fnames, p=None, subj=None, img_form='webp'):
t0 = time.time()
section = 'Good HPI count'
print((' %s ... ' % section).ljust(LJUST), end='')
Expand All @@ -132,11 +133,11 @@ def _report_good_hpi(report, fnames, p=None, subj=None):
figs.append(fig)
captions.append('%s: %s' % (section, op.basename(fname)))
_add_figs_to_section(
report, figs, captions, section, image_format='png')
report, figs, captions, section, image_format=img_form)
print('%5.1f sec' % ((time.time() - t0),))


def _report_chpi_snr(report, fnames, p=None, subj=None):
def _report_chpi_snr(report, fnames, p=None, subj=None, img_form='webp'):
t0 = time.time()
section = 'cHPI SNR'
print((' %s ... ' % section).ljust(LJUST), end='')
Expand All @@ -155,11 +156,12 @@ def _report_chpi_snr(report, fnames, p=None, subj=None):
figs.append(fig)
captions.append('%s: %s' % (section, op.basename(fname)))
_add_figs_to_section(
report, figs, captions, section, image_format='png') # svd too slow
report, figs, captions, section, image_format=img_form) # svd too slow
print('%5.1f sec' % ((time.time() - t0),))


def _report_head_movement(report, fnames, p=None, subj=None, run_indices=None):
def _report_head_movement(report, fnames, p=None, subj=None, run_indices=None,
img_form='webp'):
section = 'Head movement'
print((' %s ... ' % section).ljust(LJUST), end='')
t0 = time.time()
Expand Down Expand Up @@ -200,11 +202,12 @@ def _report_head_movement(report, fnames, p=None, subj=None, run_indices=None):
figs.append(fig)
captions.append('%s: %s' % (section, op.basename(fname)))
del trans_to
_add_figs_to_section(report, figs, captions, section, image_format='png')
_add_figs_to_section(report, figs, captions, section,
image_format=img_form)
print('%5.1f sec' % ((time.time() - t0),))


def _report_events(report, fnames, p=None, subj=None):
def _report_events(report, fnames, p=None, subj=None, img_form='webp'):
t0 = time.time()
section = 'Events'
print((' %s ... ' % section).ljust(LJUST), end='')
Expand All @@ -222,11 +225,11 @@ def _report_events(report, fnames, p=None, subj=None):
captions.append('%s: %s' % (section, op.basename(fname)))
if len(figs):
_add_figs_to_section(
report, figs, captions, section, image_format='png')
report, figs, captions, section, image_format=img_form)
print('%5.1f sec' % ((time.time() - t0),))


def _report_raw_segments(report, raw, lowpass=None):
def _report_raw_segments(report, raw, lowpass=None, img_form='webp'):
t0 = time.time()
section = 'Raw segments'
print((' %s ... ' % section).ljust(LJUST), end='')
Expand Down Expand Up @@ -263,12 +266,21 @@ def _report_raw_segments(report, raw, lowpass=None):
fig.set(figheight=(fig.axes[0].get_yticks() != 0).sum(),
figwidth=12)
fig.subplots_adjust(0.025, 0.0, 1, 1, 0, 0)
_add_figs_to_section(report, fig, section, section, image_format='png')
_add_figs_to_section(report, fig, section, section, image_format=img_form)
print('%5.1f sec' % ((time.time() - t0),))


def _gen_psd_plot(raw, fmax, n_fft, ax):
if hasattr(mne.time_frequency.Spectrum, 'plot'):
plot = raw.compute_psd(fmax=fmax, n_fft=n_fft).plot(show=False,
axes=ax)
drammock marked this conversation as resolved.
Show resolved Hide resolved
else:
plot = raw.plot_psd(fmax=fmax, n_fft=n_fft, show=False, ax=ax)
return plot


def _report_raw_psd(report, raw, raw_pca=None, raw_erm=None, raw_erm_pca=None,
p=None):
p=None, img_form='webp'):
t0 = time.time()
section = 'PSD'
import matplotlib.pyplot as plt
Expand All @@ -283,7 +295,7 @@ def _report_raw_psd(report, raw, raw_pca=None, raw_erm=None, raw_erm_pca=None,
fmax = raw.info['lowpass']
n_ax = sum(key in raw for key in ('mag', 'grad', 'eeg'))
_, ax = plt.subplots(n_ax, figsize=(10, 8))
figs = [raw.plot_psd(fmax=fmax, n_fft=n_fft, show=False, ax=ax)]
figs = [_gen_psd_plot(raw, fmax=fmax, n_fft=n_fft, ax=ax)]
captions = ['%s: Raw' % section]
fmax = lp_cut + 2 * lp_trans
for this_raw, caption in [
Expand All @@ -293,8 +305,7 @@ def _report_raw_psd(report, raw, raw_pca=None, raw_erm=None, raw_erm_pca=None,
(raw_erm_pca, f'{section}: ERM processed (zoomed)')]:
_, ax = plt.subplots(n_ax, figsize=(10, 8))
if this_raw is not None:
figs.append(
this_raw.plot_psd(fmax=fmax, n_fft=n_fft, show=False, ax=ax))
figs.append(_gen_psd_plot(this_raw, fmax=fmax, n_fft=n_fft, ax=ax))
captions.append(caption)
# shared y limits
n = len(figs[0].axes) // 2
Expand All @@ -306,7 +317,7 @@ def _report_raw_psd(report, raw, raw_pca=None, raw_erm=None, raw_erm_pca=None,
ax.set_ylim(ylims)
ax.set(title='')
_add_figs_to_section(
report, figs, captions, section, image_format='png')
report, figs, captions, section, image_format=img_form)
print('%5.1f sec' % ((time.time() - t0),))


Expand Down Expand Up @@ -375,9 +386,6 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
import matplotlib.pyplot as plt
if run_indices is None:
run_indices = [None] * len(subjects)
time_kwargs = dict()
if 'time_unit' in mne.fixes._get_args(mne.viz.plot_evoked):
time_kwargs['time_unit'] = 's'
known_keys = {
'good_hpi_count', 'chpi_snr', 'head_movement', 'raw_segments', 'psd',
'ssp_topomaps', 'source_alignment', 'drop_log', 'bem', 'covariance',
Expand Down Expand Up @@ -445,6 +453,12 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
subj + p.inv_tag + '-fwd.fif'))

with report_context():
#
# Set report image format
#
img_form = report.image_format
logger.info(f'Setting default image format to {img_form}.')

#
# Custom pre-fun
#
Expand All @@ -459,23 +473,24 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
# Head coils
#
if p.report_params.get('good_hpi_count', True) and p.movecomp:
_report_good_hpi(report, fnames, p, subj)
_report_good_hpi(report, fnames, p, subj, img_form=img_form)
else:
print(' HPI count skipped')

#
# cHPI SNR
#
if p.report_params.get('chpi_snr', True) and p.movecomp:
_report_chpi_snr(report, fnames, p, subj)
_report_chpi_snr(report, fnames, p, subj, img_form=img_form)
else:
print(' cHPI SNR skipped')

#
# Head movement
#
if p.report_params.get('head_movement', True) and p.movecomp:
_report_head_movement(report, fnames, p, subj, run_indices[si])
_report_head_movement(report, fnames, p, subj, run_indices[si],
img_form=img_form)
else:
print(' Head movement skipped')

Expand All @@ -484,15 +499,16 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
#
if p.report_params.get('raw_segments', True) and \
raw_pca is not None:
_report_raw_segments(report, raw_pca)
_report_raw_segments(report, raw_pca, img_form=img_form)
else:
print(' Raw segments skipped')

#
# PSD
#
if p.report_params.get('psd', True):
_report_raw_psd(report, raw, raw_pca, raw_erm, raw_erm_pca, p)
_report_raw_psd(report, raw, raw_pca, raw_erm, raw_erm_pca, p,
img_form=img_form)
else:
print(' PSD skipped')

Expand Down Expand Up @@ -553,7 +569,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
duration))
captions = ['SSP epochs: %s' % c for c in comments]
_add_figs_to_section(
report, figs, captions, section, image_format='png')
report, figs, captions, section, image_format=img_form)
print('%5.1f sec' % ((time.time() - t0),))
else:
print(' %s skipped' % section)
Expand Down Expand Up @@ -638,7 +654,8 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
from mayavi import mlab
mlab.close(fig)
view = trim_bg(np.concatenate(view, axis=1), 0)
_add_figs_to_section(report, [view], captions, section)
_add_figs_to_section(report, [view], captions, section,
image_format=img_form)
print('%5.1f sec' % ((time.time() - t0),))
else:
print(' %s skipped' % section)
Expand All @@ -654,7 +671,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
figs = [epo.plot_drop_log(subject=subj, show=False)]
captions = [repr(epo)]
_add_figs_to_section(
report, figs, captions, section, image_format='svg')
report, figs, captions, section, image_format=img_form)
print('%5.1f sec' % ((time.time() - t0),))
else:
print(' %s skipped' % section)
Expand Down Expand Up @@ -711,7 +728,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
for kind in ('images', 'SVDs')]
_add_figs_to_section(
report, figs, captions, section=section,
image_format='png')
image_format=img_form)
print('%5.1f sec' % ((time.time() - t0),))
else:
print(' %s skipped' % section)
Expand Down Expand Up @@ -813,7 +830,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
figs = [figs]
_add_figs_to_section(
report, figs, captions, section=section,
image_format='png')
image_format=img_form)
print('%5.1f sec' % ((time.time() - t0),))
else:
print(' %s skipped' % section)
Expand Down Expand Up @@ -885,7 +902,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
figs.tight_layout()
_add_figs_to_section(
report, figs, captions, section=section,
image_format='png')
image_format=img_form)
print('%5.1f sec' % ((time.time() - t0),))

#
Expand Down Expand Up @@ -968,13 +985,22 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
this_evoked.add_proj(all_proj)
if this_evoked.nave > 0:
with mne.utils.use_log_level('error'):
fig = this_evoked.plot_joint(
times, show=False, picks=picks,
ts_args=dict(proj=proj),
topomap_args=dict(
outlines='head',
vmin=min_, vmax=max_,
cmap=cmap, proj=proj))
try:
fig = this_evoked.plot_joint(
times, show=False, picks=picks,
ts_args=dict(proj=proj),
topomap_args=dict(
outlines='head',
vlim=(min_, max_),
cmap=cmap, proj=proj))
except TypeError:
fig = this_evoked.plot_joint(
times, show=False, picks=picks,
ts_args=dict(proj=proj),
topomap_args=dict(
outlines='head',
vmin=min_, vmax=max_,
cmap=cmap, proj=proj))
drammock marked this conversation as resolved.
Show resolved Hide resolved
assert isinstance(fig, plt.Figure)
fig.axes[0].set(ylim=(-max_, max_))
t = fig.axes[-1].texts[0]
Expand All @@ -993,7 +1019,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
title += ' : SSP on'
_add_slider_to_section(
report, all_figs, all_captions, section=section,
title=title, image_format='png')
title=title, image_format=img_form)
del this_evoked, all_evoked
print('%5.1f sec' % ((time.time() - t0),))

Expand Down Expand Up @@ -1139,7 +1165,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
print(repr(captions))
_add_slider_to_section(
report, figs, captions=captions, section=section,
title=title, image_format='png')
title=title, image_format=img_form)

print('%5.1f sec' % ((time.time() - t0),))
else:
Expand Down