diff --git a/mnefun/_forward.py b/mnefun/_forward.py index 7300bc5..c3cbdb2 100644 --- a/mnefun/_forward.py +++ b/mnefun/_forward.py @@ -74,7 +74,7 @@ def gen_forwards(p, subjects, structurals, run_indices): def _get_bem_src_trans(p, info, subj, struc): - subjects_dir = get_subjects_dir(p.subjects_dir, raise_error=True) + subjects_dir = str(get_subjects_dir(p.subjects_dir, raise_error=True)) assert isinstance(subjects_dir, str) if struc is None: # spherical case bem, src, trans = _spherical_conductor(info, subj, p.src_pos) diff --git a/mnefun/_report.py b/mnefun/_report.py index 4d305ba..ebb97b0 100644 --- a/mnefun/_report.py +++ b/mnefun/_report.py @@ -18,7 +18,7 @@ plot_events) from mne.viz._3d import plot_head_positions from mne.report import Report -from mne.utils import _pl, use_log_level +from mne.utils import _pl, use_log_level, logger from mne.cov import whiten_evoked from mne.viz.utils import _triage_rank_sss @@ -48,17 +48,18 @@ def report_context(): plt.ioff() old_backend = matplotlib.get_backend() matplotlib.use('Agg', force=True) - try: - with plt.style.context(style): - yield - except Exception: - matplotlib.use(old_backend, force=True) - plt.interactive(is_interactive) - raise + with mne.viz.use_browser_backend('matplotlib'): + try: + with plt.style.context(style): + yield + except Exception: + matplotlib.use(old_backend, force=True) + plt.interactive(is_interactive) + raise # Backward compat wrappers for MNE 1.0+ -def _add_figs_to_section(report, figs, captions, section, image_format='png'): +def _add_figs_to_section(report, figs, captions, section, image_format='webp'): try: report.add_figure except AttributeError: @@ -82,7 +83,7 @@ def _add_figs_to_section(report, figs, captions, section, image_format='png'): def _add_slider_to_section(report, figs, captions, section, title, - image_format='png'): + image_format='webp'): try: report.add_figure except AttributeError: @@ -113,7 +114,7 @@ def _check_fname_raw(fname, p, subj): return fname, raw -def _report_good_hpi(report, fnames, p=None, subj=None): +def _report_good_hpi(report, fnames, p=None, subj=None, img_form='webp'): t0 = time.time() section = 'Good HPI count' print((' %s ... ' % section).ljust(LJUST), end='') @@ -132,11 +133,11 @@ def _report_good_hpi(report, fnames, p=None, subj=None): figs.append(fig) captions.append('%s: %s' % (section, op.basename(fname))) _add_figs_to_section( - report, figs, captions, section, image_format='png') + report, figs, captions, section, image_format=img_form) print('%5.1f sec' % ((time.time() - t0),)) -def _report_chpi_snr(report, fnames, p=None, subj=None): +def _report_chpi_snr(report, fnames, p=None, subj=None, img_form='webp'): t0 = time.time() section = 'cHPI SNR' print((' %s ... ' % section).ljust(LJUST), end='') @@ -155,11 +156,12 @@ def _report_chpi_snr(report, fnames, p=None, subj=None): figs.append(fig) captions.append('%s: %s' % (section, op.basename(fname))) _add_figs_to_section( - report, figs, captions, section, image_format='png') # svd too slow + report, figs, captions, section, image_format=img_form) # svd too slow print('%5.1f sec' % ((time.time() - t0),)) -def _report_head_movement(report, fnames, p=None, subj=None, run_indices=None): +def _report_head_movement(report, fnames, p=None, subj=None, run_indices=None, + img_form='webp'): section = 'Head movement' print((' %s ... ' % section).ljust(LJUST), end='') t0 = time.time() @@ -200,11 +202,12 @@ def _report_head_movement(report, fnames, p=None, subj=None, run_indices=None): figs.append(fig) captions.append('%s: %s' % (section, op.basename(fname))) del trans_to - _add_figs_to_section(report, figs, captions, section, image_format='png') + _add_figs_to_section(report, figs, captions, section, + image_format=img_form) print('%5.1f sec' % ((time.time() - t0),)) -def _report_events(report, fnames, p=None, subj=None): +def _report_events(report, fnames, p=None, subj=None, img_form='webp'): t0 = time.time() section = 'Events' print((' %s ... ' % section).ljust(LJUST), end='') @@ -222,11 +225,11 @@ def _report_events(report, fnames, p=None, subj=None): captions.append('%s: %s' % (section, op.basename(fname))) if len(figs): _add_figs_to_section( - report, figs, captions, section, image_format='png') + report, figs, captions, section, image_format=img_form) print('%5.1f sec' % ((time.time() - t0),)) -def _report_raw_segments(report, raw, lowpass=None): +def _report_raw_segments(report, raw, lowpass=None, img_form='webp'): t0 = time.time() section = 'Raw segments' print((' %s ... ' % section).ljust(LJUST), end='') @@ -263,12 +266,21 @@ def _report_raw_segments(report, raw, lowpass=None): fig.set(figheight=(fig.axes[0].get_yticks() != 0).sum(), figwidth=12) fig.subplots_adjust(0.025, 0.0, 1, 1, 0, 0) - _add_figs_to_section(report, fig, section, section, image_format='png') + _add_figs_to_section(report, fig, section, section, image_format=img_form) print('%5.1f sec' % ((time.time() - t0),)) +def _gen_psd_plot(raw, fmax, n_fft, ax): + if hasattr(raw, 'compute_psd'): + plot = raw.compute_psd(fmax=fmax, n_fft=n_fft).plot(show=False, + axes=ax) + else: + plot = raw.plot_psd(fmax=fmax, n_fft=n_fft, show=False, ax=ax) + return plot + + def _report_raw_psd(report, raw, raw_pca=None, raw_erm=None, raw_erm_pca=None, - p=None): + p=None, img_form='webp'): t0 = time.time() section = 'PSD' import matplotlib.pyplot as plt @@ -283,7 +295,7 @@ def _report_raw_psd(report, raw, raw_pca=None, raw_erm=None, raw_erm_pca=None, fmax = raw.info['lowpass'] n_ax = sum(key in raw for key in ('mag', 'grad', 'eeg')) _, ax = plt.subplots(n_ax, figsize=(10, 8)) - figs = [raw.plot_psd(fmax=fmax, n_fft=n_fft, show=False, ax=ax)] + figs = [_gen_psd_plot(raw, fmax=fmax, n_fft=n_fft, ax=ax)] captions = ['%s: Raw' % section] fmax = lp_cut + 2 * lp_trans for this_raw, caption in [ @@ -293,8 +305,7 @@ def _report_raw_psd(report, raw, raw_pca=None, raw_erm=None, raw_erm_pca=None, (raw_erm_pca, f'{section}: ERM processed (zoomed)')]: _, ax = plt.subplots(n_ax, figsize=(10, 8)) if this_raw is not None: - figs.append( - this_raw.plot_psd(fmax=fmax, n_fft=n_fft, show=False, ax=ax)) + figs.append(_gen_psd_plot(this_raw, fmax=fmax, n_fft=n_fft, ax=ax)) captions.append(caption) # shared y limits n = len(figs[0].axes) // 2 @@ -306,7 +317,7 @@ def _report_raw_psd(report, raw, raw_pca=None, raw_erm=None, raw_erm_pca=None, ax.set_ylim(ylims) ax.set(title='') _add_figs_to_section( - report, figs, captions, section, image_format='png') + report, figs, captions, section, image_format=img_form) print('%5.1f sec' % ((time.time() - t0),)) @@ -375,9 +386,6 @@ def gen_html_report(p, subjects, structurals, run_indices=None): import matplotlib.pyplot as plt if run_indices is None: run_indices = [None] * len(subjects) - time_kwargs = dict() - if 'time_unit' in mne.fixes._get_args(mne.viz.plot_evoked): - time_kwargs['time_unit'] = 's' known_keys = { 'good_hpi_count', 'chpi_snr', 'head_movement', 'raw_segments', 'psd', 'ssp_topomaps', 'source_alignment', 'drop_log', 'bem', 'covariance', @@ -445,6 +453,12 @@ def gen_html_report(p, subjects, structurals, run_indices=None): subj + p.inv_tag + '-fwd.fif')) with report_context(): + # + # Set report image format + # + img_form = report.image_format + logger.info(f'Setting default image format to {img_form}.') + # # Custom pre-fun # @@ -459,7 +473,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None): # Head coils # if p.report_params.get('good_hpi_count', True) and p.movecomp: - _report_good_hpi(report, fnames, p, subj) + _report_good_hpi(report, fnames, p, subj, img_form=img_form) else: print(' HPI count skipped') @@ -467,7 +481,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None): # cHPI SNR # if p.report_params.get('chpi_snr', True) and p.movecomp: - _report_chpi_snr(report, fnames, p, subj) + _report_chpi_snr(report, fnames, p, subj, img_form=img_form) else: print(' cHPI SNR skipped') @@ -475,7 +489,8 @@ def gen_html_report(p, subjects, structurals, run_indices=None): # Head movement # if p.report_params.get('head_movement', True) and p.movecomp: - _report_head_movement(report, fnames, p, subj, run_indices[si]) + _report_head_movement(report, fnames, p, subj, run_indices[si], + img_form=img_form) else: print(' Head movement skipped') @@ -484,7 +499,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None): # if p.report_params.get('raw_segments', True) and \ raw_pca is not None: - _report_raw_segments(report, raw_pca) + _report_raw_segments(report, raw_pca, img_form=img_form) else: print(' Raw segments skipped') @@ -492,7 +507,8 @@ def gen_html_report(p, subjects, structurals, run_indices=None): # PSD # if p.report_params.get('psd', True): - _report_raw_psd(report, raw, raw_pca, raw_erm, raw_erm_pca, p) + _report_raw_psd(report, raw, raw_pca, raw_erm, raw_erm_pca, p, + img_form=img_form) else: print(' PSD skipped') @@ -553,7 +569,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None): duration)) captions = ['SSP epochs: %s' % c for c in comments] _add_figs_to_section( - report, figs, captions, section, image_format='png') + report, figs, captions, section, image_format=img_form) print('%5.1f sec' % ((time.time() - t0),)) else: print(' %s skipped' % section) @@ -638,7 +654,8 @@ def gen_html_report(p, subjects, structurals, run_indices=None): from mayavi import mlab mlab.close(fig) view = trim_bg(np.concatenate(view, axis=1), 0) - _add_figs_to_section(report, [view], captions, section) + _add_figs_to_section(report, [view], captions, section, + image_format=img_form) print('%5.1f sec' % ((time.time() - t0),)) else: print(' %s skipped' % section) @@ -654,7 +671,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None): figs = [epo.plot_drop_log(subject=subj, show=False)] captions = [repr(epo)] _add_figs_to_section( - report, figs, captions, section, image_format='svg') + report, figs, captions, section, image_format=img_form) print('%5.1f sec' % ((time.time() - t0),)) else: print(' %s skipped' % section) @@ -711,7 +728,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None): for kind in ('images', 'SVDs')] _add_figs_to_section( report, figs, captions, section=section, - image_format='png') + image_format=img_form) print('%5.1f sec' % ((time.time() - t0),)) else: print(' %s skipped' % section) @@ -813,7 +830,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None): figs = [figs] _add_figs_to_section( report, figs, captions, section=section, - image_format='png') + image_format=img_form) print('%5.1f sec' % ((time.time() - t0),)) else: print(' %s skipped' % section) @@ -885,7 +902,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None): figs.tight_layout() _add_figs_to_section( report, figs, captions, section=section, - image_format='png') + image_format=img_form) print('%5.1f sec' % ((time.time() - t0),)) # @@ -968,13 +985,22 @@ def gen_html_report(p, subjects, structurals, run_indices=None): this_evoked.add_proj(all_proj) if this_evoked.nave > 0: with mne.utils.use_log_level('error'): - fig = this_evoked.plot_joint( - times, show=False, picks=picks, - ts_args=dict(proj=proj), - topomap_args=dict( - outlines='head', - vmin=min_, vmax=max_, - cmap=cmap, proj=proj)) + topomap_args = dict(outlines='head', + vlim=(min_, max_), + cmap=cmap, proj=proj) + kwargs = dict(show=False, picks=picks, + ts_args=dict(proj=proj), + topomap_args=topomap_args) + try: + fig = this_evoked.plot_joint( + times, **kwargs) + except TypeError: + # old MNE had separate vmin, vmax + topomap_args.update( + zip(('vmin', 'vmax'), + topomap_args.pop('vlim'))) + fig = this_evoked.plot_joint( + times, **kwargs) assert isinstance(fig, plt.Figure) fig.axes[0].set(ylim=(-max_, max_)) t = fig.axes[-1].texts[0] @@ -993,7 +1019,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None): title += ' : SSP on' _add_slider_to_section( report, all_figs, all_captions, section=section, - title=title, image_format='png') + title=title, image_format=img_form) del this_evoked, all_evoked print('%5.1f sec' % ((time.time() - t0),)) @@ -1139,7 +1165,7 @@ def gen_html_report(p, subjects, structurals, run_indices=None): print(repr(captions)) _add_slider_to_section( report, figs, captions=captions, section=section, - title=title, image_format='png') + title=title, image_format=img_form) print('%5.1f sec' % ((time.time() - t0),)) else: