Skip to content

Commit

Permalink
Update mnefun._reports to work with current mne-python (#364)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel McCloy <dan@mccloy.info>
  • Loading branch information
nordme and drammock authored Apr 4, 2023
1 parent dedf9fa commit df17487
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 49 deletions.
2 changes: 1 addition & 1 deletion mnefun/_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def gen_forwards(p, subjects, structurals, run_indices):


def _get_bem_src_trans(p, info, subj, struc):
subjects_dir = get_subjects_dir(p.subjects_dir, raise_error=True)
subjects_dir = str(get_subjects_dir(p.subjects_dir, raise_error=True))
assert isinstance(subjects_dir, str)
if struc is None: # spherical case
bem, src, trans = _spherical_conductor(info, subj, p.src_pos)
Expand Down
122 changes: 74 additions & 48 deletions mnefun/_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
plot_events)
from mne.viz._3d import plot_head_positions
from mne.report import Report
from mne.utils import _pl, use_log_level
from mne.utils import _pl, use_log_level, logger
from mne.cov import whiten_evoked
from mne.viz.utils import _triage_rank_sss

Expand Down Expand Up @@ -48,17 +48,18 @@ def report_context():
plt.ioff()
old_backend = matplotlib.get_backend()
matplotlib.use('Agg', force=True)
try:
with plt.style.context(style):
yield
except Exception:
matplotlib.use(old_backend, force=True)
plt.interactive(is_interactive)
raise
with mne.viz.use_browser_backend('matplotlib'):
try:
with plt.style.context(style):
yield
except Exception:
matplotlib.use(old_backend, force=True)
plt.interactive(is_interactive)
raise


# Backward compat wrappers for MNE 1.0+
def _add_figs_to_section(report, figs, captions, section, image_format='png'):
def _add_figs_to_section(report, figs, captions, section, image_format='webp'):
try:
report.add_figure
except AttributeError:
Expand All @@ -82,7 +83,7 @@ def _add_figs_to_section(report, figs, captions, section, image_format='png'):


def _add_slider_to_section(report, figs, captions, section, title,
image_format='png'):
image_format='webp'):
try:
report.add_figure
except AttributeError:
Expand Down Expand Up @@ -113,7 +114,7 @@ def _check_fname_raw(fname, p, subj):
return fname, raw


def _report_good_hpi(report, fnames, p=None, subj=None):
def _report_good_hpi(report, fnames, p=None, subj=None, img_form='webp'):
t0 = time.time()
section = 'Good HPI count'
print((' %s ... ' % section).ljust(LJUST), end='')
Expand All @@ -132,11 +133,11 @@ def _report_good_hpi(report, fnames, p=None, subj=None):
figs.append(fig)
captions.append('%s: %s' % (section, op.basename(fname)))
_add_figs_to_section(
report, figs, captions, section, image_format='png')
report, figs, captions, section, image_format=img_form)
print('%5.1f sec' % ((time.time() - t0),))


def _report_chpi_snr(report, fnames, p=None, subj=None):
def _report_chpi_snr(report, fnames, p=None, subj=None, img_form='webp'):
t0 = time.time()
section = 'cHPI SNR'
print((' %s ... ' % section).ljust(LJUST), end='')
Expand All @@ -155,11 +156,12 @@ def _report_chpi_snr(report, fnames, p=None, subj=None):
figs.append(fig)
captions.append('%s: %s' % (section, op.basename(fname)))
_add_figs_to_section(
report, figs, captions, section, image_format='png') # svd too slow
report, figs, captions, section, image_format=img_form) # svd too slow
print('%5.1f sec' % ((time.time() - t0),))


def _report_head_movement(report, fnames, p=None, subj=None, run_indices=None):
def _report_head_movement(report, fnames, p=None, subj=None, run_indices=None,
img_form='webp'):
section = 'Head movement'
print((' %s ... ' % section).ljust(LJUST), end='')
t0 = time.time()
Expand Down Expand Up @@ -200,11 +202,12 @@ def _report_head_movement(report, fnames, p=None, subj=None, run_indices=None):
figs.append(fig)
captions.append('%s: %s' % (section, op.basename(fname)))
del trans_to
_add_figs_to_section(report, figs, captions, section, image_format='png')
_add_figs_to_section(report, figs, captions, section,
image_format=img_form)
print('%5.1f sec' % ((time.time() - t0),))


def _report_events(report, fnames, p=None, subj=None):
def _report_events(report, fnames, p=None, subj=None, img_form='webp'):
t0 = time.time()
section = 'Events'
print((' %s ... ' % section).ljust(LJUST), end='')
Expand All @@ -222,11 +225,11 @@ def _report_events(report, fnames, p=None, subj=None):
captions.append('%s: %s' % (section, op.basename(fname)))
if len(figs):
_add_figs_to_section(
report, figs, captions, section, image_format='png')
report, figs, captions, section, image_format=img_form)
print('%5.1f sec' % ((time.time() - t0),))


def _report_raw_segments(report, raw, lowpass=None):
def _report_raw_segments(report, raw, lowpass=None, img_form='webp'):
t0 = time.time()
section = 'Raw segments'
print((' %s ... ' % section).ljust(LJUST), end='')
Expand Down Expand Up @@ -263,12 +266,21 @@ def _report_raw_segments(report, raw, lowpass=None):
fig.set(figheight=(fig.axes[0].get_yticks() != 0).sum(),
figwidth=12)
fig.subplots_adjust(0.025, 0.0, 1, 1, 0, 0)
_add_figs_to_section(report, fig, section, section, image_format='png')
_add_figs_to_section(report, fig, section, section, image_format=img_form)
print('%5.1f sec' % ((time.time() - t0),))


def _gen_psd_plot(raw, fmax, n_fft, ax):
if hasattr(raw, 'compute_psd'):
plot = raw.compute_psd(fmax=fmax, n_fft=n_fft).plot(show=False,
axes=ax)
else:
plot = raw.plot_psd(fmax=fmax, n_fft=n_fft, show=False, ax=ax)
return plot


def _report_raw_psd(report, raw, raw_pca=None, raw_erm=None, raw_erm_pca=None,
p=None):
p=None, img_form='webp'):
t0 = time.time()
section = 'PSD'
import matplotlib.pyplot as plt
Expand All @@ -283,7 +295,7 @@ def _report_raw_psd(report, raw, raw_pca=None, raw_erm=None, raw_erm_pca=None,
fmax = raw.info['lowpass']
n_ax = sum(key in raw for key in ('mag', 'grad', 'eeg'))
_, ax = plt.subplots(n_ax, figsize=(10, 8))
figs = [raw.plot_psd(fmax=fmax, n_fft=n_fft, show=False, ax=ax)]
figs = [_gen_psd_plot(raw, fmax=fmax, n_fft=n_fft, ax=ax)]
captions = ['%s: Raw' % section]
fmax = lp_cut + 2 * lp_trans
for this_raw, caption in [
Expand All @@ -293,8 +305,7 @@ def _report_raw_psd(report, raw, raw_pca=None, raw_erm=None, raw_erm_pca=None,
(raw_erm_pca, f'{section}: ERM processed (zoomed)')]:
_, ax = plt.subplots(n_ax, figsize=(10, 8))
if this_raw is not None:
figs.append(
this_raw.plot_psd(fmax=fmax, n_fft=n_fft, show=False, ax=ax))
figs.append(_gen_psd_plot(this_raw, fmax=fmax, n_fft=n_fft, ax=ax))
captions.append(caption)
# shared y limits
n = len(figs[0].axes) // 2
Expand All @@ -306,7 +317,7 @@ def _report_raw_psd(report, raw, raw_pca=None, raw_erm=None, raw_erm_pca=None,
ax.set_ylim(ylims)
ax.set(title='')
_add_figs_to_section(
report, figs, captions, section, image_format='png')
report, figs, captions, section, image_format=img_form)
print('%5.1f sec' % ((time.time() - t0),))


Expand Down Expand Up @@ -375,9 +386,6 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
import matplotlib.pyplot as plt
if run_indices is None:
run_indices = [None] * len(subjects)
time_kwargs = dict()
if 'time_unit' in mne.fixes._get_args(mne.viz.plot_evoked):
time_kwargs['time_unit'] = 's'
known_keys = {
'good_hpi_count', 'chpi_snr', 'head_movement', 'raw_segments', 'psd',
'ssp_topomaps', 'source_alignment', 'drop_log', 'bem', 'covariance',
Expand Down Expand Up @@ -445,6 +453,12 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
subj + p.inv_tag + '-fwd.fif'))

with report_context():
#
# Set report image format
#
img_form = report.image_format
logger.info(f'Setting default image format to {img_form}.')

#
# Custom pre-fun
#
Expand All @@ -459,23 +473,24 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
# Head coils
#
if p.report_params.get('good_hpi_count', True) and p.movecomp:
_report_good_hpi(report, fnames, p, subj)
_report_good_hpi(report, fnames, p, subj, img_form=img_form)
else:
print(' HPI count skipped')

#
# cHPI SNR
#
if p.report_params.get('chpi_snr', True) and p.movecomp:
_report_chpi_snr(report, fnames, p, subj)
_report_chpi_snr(report, fnames, p, subj, img_form=img_form)
else:
print(' cHPI SNR skipped')

#
# Head movement
#
if p.report_params.get('head_movement', True) and p.movecomp:
_report_head_movement(report, fnames, p, subj, run_indices[si])
_report_head_movement(report, fnames, p, subj, run_indices[si],
img_form=img_form)
else:
print(' Head movement skipped')

Expand All @@ -484,15 +499,16 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
#
if p.report_params.get('raw_segments', True) and \
raw_pca is not None:
_report_raw_segments(report, raw_pca)
_report_raw_segments(report, raw_pca, img_form=img_form)
else:
print(' Raw segments skipped')

#
# PSD
#
if p.report_params.get('psd', True):
_report_raw_psd(report, raw, raw_pca, raw_erm, raw_erm_pca, p)
_report_raw_psd(report, raw, raw_pca, raw_erm, raw_erm_pca, p,
img_form=img_form)
else:
print(' PSD skipped')

Expand Down Expand Up @@ -553,7 +569,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
duration))
captions = ['SSP epochs: %s' % c for c in comments]
_add_figs_to_section(
report, figs, captions, section, image_format='png')
report, figs, captions, section, image_format=img_form)
print('%5.1f sec' % ((time.time() - t0),))
else:
print(' %s skipped' % section)
Expand Down Expand Up @@ -638,7 +654,8 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
from mayavi import mlab
mlab.close(fig)
view = trim_bg(np.concatenate(view, axis=1), 0)
_add_figs_to_section(report, [view], captions, section)
_add_figs_to_section(report, [view], captions, section,
image_format=img_form)
print('%5.1f sec' % ((time.time() - t0),))
else:
print(' %s skipped' % section)
Expand All @@ -654,7 +671,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
figs = [epo.plot_drop_log(subject=subj, show=False)]
captions = [repr(epo)]
_add_figs_to_section(
report, figs, captions, section, image_format='svg')
report, figs, captions, section, image_format=img_form)
print('%5.1f sec' % ((time.time() - t0),))
else:
print(' %s skipped' % section)
Expand Down Expand Up @@ -711,7 +728,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
for kind in ('images', 'SVDs')]
_add_figs_to_section(
report, figs, captions, section=section,
image_format='png')
image_format=img_form)
print('%5.1f sec' % ((time.time() - t0),))
else:
print(' %s skipped' % section)
Expand Down Expand Up @@ -813,7 +830,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
figs = [figs]
_add_figs_to_section(
report, figs, captions, section=section,
image_format='png')
image_format=img_form)
print('%5.1f sec' % ((time.time() - t0),))
else:
print(' %s skipped' % section)
Expand Down Expand Up @@ -885,7 +902,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
figs.tight_layout()
_add_figs_to_section(
report, figs, captions, section=section,
image_format='png')
image_format=img_form)
print('%5.1f sec' % ((time.time() - t0),))

#
Expand Down Expand Up @@ -968,13 +985,22 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
this_evoked.add_proj(all_proj)
if this_evoked.nave > 0:
with mne.utils.use_log_level('error'):
fig = this_evoked.plot_joint(
times, show=False, picks=picks,
ts_args=dict(proj=proj),
topomap_args=dict(
outlines='head',
vmin=min_, vmax=max_,
cmap=cmap, proj=proj))
topomap_args = dict(outlines='head',
vlim=(min_, max_),
cmap=cmap, proj=proj)
kwargs = dict(show=False, picks=picks,
ts_args=dict(proj=proj),
topomap_args=topomap_args)
try:
fig = this_evoked.plot_joint(
times, **kwargs)
except TypeError:
# old MNE had separate vmin, vmax
topomap_args.update(
zip(('vmin', 'vmax'),
topomap_args.pop('vlim')))
fig = this_evoked.plot_joint(
times, **kwargs)
assert isinstance(fig, plt.Figure)
fig.axes[0].set(ylim=(-max_, max_))
t = fig.axes[-1].texts[0]
Expand All @@ -993,7 +1019,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
title += ' : SSP on'
_add_slider_to_section(
report, all_figs, all_captions, section=section,
title=title, image_format='png')
title=title, image_format=img_form)
del this_evoked, all_evoked
print('%5.1f sec' % ((time.time() - t0),))

Expand Down Expand Up @@ -1139,7 +1165,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None):
print(repr(captions))
_add_slider_to_section(
report, figs, captions=captions, section=section,
title=title, image_format='png')
title=title, image_format=img_form)

print('%5.1f sec' % ((time.time() - t0),))
else:
Expand Down

0 comments on commit df17487

Please sign in to comment.