diff --git a/fooof/objs/fit.py b/fooof/objs/fit.py index 009671b55..7e5333d50 100644 --- a/fooof/objs/fit.py +++ b/fooof/objs/fit.py @@ -68,7 +68,6 @@ gen_issue_str, gen_width_warning_str) from fooof.plts.fm import plot_fm -from fooof.plts.style import style_spectrum_plot from fooof.utils.data import trim_spectrum from fooof.utils.params import compute_gauss_std from fooof.data import FOOOFResults, FOOOFSettings, FOOOFMetaData @@ -617,12 +616,13 @@ def get_results(self): @copy_doc_func_to_method(plot_fm) def plot(self, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend=True, save_fig=False, file_name=None, file_path=None, - ax=None, plot_style=style_spectrum_plot, - data_kwargs=None, model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None): + ax=None, data_kwargs=None, model_kwargs=None, + aperiodic_kwargs=None, peak_kwargs=None, **plot_kwargs): - plot_fm(self, plot_peaks, plot_aperiodic, plt_log, add_legend, - save_fig, file_name, file_path, ax, plot_style, - data_kwargs, model_kwargs, aperiodic_kwargs, peak_kwargs) + plot_fm(self, plot_peaks=plot_peaks, plot_aperiodic=plot_aperiodic, plt_log=plt_log, + add_legend=add_legend, save_fig=save_fig, file_name=file_name, + file_path=file_path, ax=ax, data_kwargs=data_kwargs, model_kwargs=model_kwargs, + aperiodic_kwargs=aperiodic_kwargs, peak_kwargs=peak_kwargs, **plot_kwargs) @copy_doc_func_to_method(save_report_fm) diff --git a/fooof/objs/group.py b/fooof/objs/group.py index 937563242..064a4bb8c 100644 --- a/fooof/objs/group.py +++ b/fooof/objs/group.py @@ -398,9 +398,9 @@ def get_params(self, name, col=None): @copy_doc_func_to_method(plot_fg) - def plot(self, save_fig=False, file_name=None, file_path=None): + def plot(self, save_fig=False, file_name=None, file_path=None, **plot_kwargs): - plot_fg(self, save_fig, file_name, file_path) + plot_fg(self, save_fig=save_fig, file_name=file_name, file_path=file_path, **plot_kwargs) @copy_doc_func_to_method(save_report_fg) diff --git a/fooof/plts/annotate.py b/fooof/plts/annotate.py index 04faacf35..093306d42 100644 --- a/fooof/plts/annotate.py +++ b/fooof/plts/annotate.py @@ -7,10 +7,10 @@ from fooof.core.funcs import gaussian_function from fooof.core.modutils import safe_import, check_dependency from fooof.sim.gen import gen_aperiodic -from fooof.plts.utils import check_ax +from fooof.plts.utils import check_ax, savefig from fooof.plts.spectra import plot_spectrum from fooof.plts.settings import PLT_FIGSIZES, PLT_COLORS -from fooof.plts.style import check_n_style, style_spectrum_plot +from fooof.plts.style import style_spectrum_plot from fooof.analysis.periodic import get_band_peak_fm from fooof.utils.params import compute_knee_frequency, compute_fwhm @@ -20,16 +20,15 @@ ################################################################################################### ################################################################################################### +@savefig @check_dependency(plt, 'matplotlib') -def plot_annotated_peak_search(fm, plot_style=style_spectrum_plot): +def plot_annotated_peak_search(fm): """Plot a series of plots illustrating the peak search from a flattened spectrum. Parameters ---------- fm : FOOOF FOOOF object, with model fit, data and settings available. - plot_style : callable, optional, default: style_spectrum_plot - A function to call to apply styling & aesthetics to the plots. """ # Recalculate the initial aperiodic fit and flattened spectrum that @@ -46,14 +45,12 @@ def plot_annotated_peak_search(fm, plot_style=style_spectrum_plot): # This forces the creation of a new plotting axes per iteration ax = check_ax(None, PLT_FIGSIZES['spectral']) - plot_spectrum(fm.freqs, flatspec, ax=ax, plot_style=None, - label='Flattened Spectrum', color=PLT_COLORS['data'], linewidth=2.5) - plot_spectrum(fm.freqs, [fm.peak_threshold * np.std(flatspec)]*len(fm.freqs), - ax=ax, plot_style=None, label='Relative Threshold', - color='orange', linewidth=2.5, linestyle='dashed') - plot_spectrum(fm.freqs, [fm.min_peak_height]*len(fm.freqs), - ax=ax, plot_style=None, label='Absolute Threshold', - color='red', linewidth=2.5, linestyle='dashed') + plot_spectrum(fm.freqs, flatspec, ax=ax, linewidth=2.5, + label='Flattened Spectrum', color=PLT_COLORS['data']) + plot_spectrum(fm.freqs, [fm.peak_threshold * np.std(flatspec)]*len(fm.freqs), ax=ax, + label='Relative Threshold', color='orange', linewidth=2.5, linestyle='dashed') + plot_spectrum(fm.freqs, [fm.min_peak_height]*len(fm.freqs), ax=ax, + label='Absolute Threshold', color='red', linewidth=2.5, linestyle='dashed') maxi = np.argmax(flatspec) ax.plot(fm.freqs[maxi], flatspec[maxi], '.', @@ -65,18 +62,18 @@ def plot_annotated_peak_search(fm, plot_style=style_spectrum_plot): if ind < fm.n_peaks_: gauss = gaussian_function(fm.freqs, *fm.gaussian_params_[ind, :]) - plot_spectrum(fm.freqs, gauss, ax=ax, plot_style=None, - label='Gaussian Fit', color=PLT_COLORS['periodic'], - linestyle=':', linewidth=3.0) + plot_spectrum(fm.freqs, gauss, ax=ax, label='Gaussian Fit', + color=PLT_COLORS['periodic'], linestyle=':', linewidth=3.0) flatspec = flatspec - gauss - check_n_style(plot_style, ax, False, True) + style_spectrum_plot(ax, False, True) +@savefig @check_dependency(plt, 'matplotlib') -def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperiodic=True, - ax=None, plot_style=style_spectrum_plot): +def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, + annotate_aperiodic=True, ax=None): """Plot a an annotated power spectrum and model, from a FOOOF object. Parameters @@ -87,8 +84,6 @@ def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperio Whether to plot the frequency values in log10 spacing. ax : matplotlib.Axes, optional Figure axes upon which to plot. - plot_style : callable, optional, default: style_spectrum_plot - A function to call to apply styling & aesthetics to the plots. Raises ------ @@ -108,7 +103,7 @@ def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperio # Create the baseline figure ax = check_ax(ax, PLT_FIGSIZES['spectral']) - fm.plot(plot_peaks='dot-shade-width', plt_log=plt_log, ax=ax, plot_style=None, + fm.plot(plot_peaks='dot-shade-width', plt_log=plt_log, ax=ax, data_kwargs={'lw' : lw1, 'alpha' : 0.6}, aperiodic_kwargs={'lw' : lw1, 'zorder' : 10}, model_kwargs={'lw' : lw1, 'alpha' : 0.5}, @@ -215,7 +210,7 @@ def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperio color=PLT_COLORS['aperiodic'], fontsize=fontsize) # Apply style to plot & tune grid styling - check_n_style(plot_style, ax, plt_log, True) + style_spectrum_plot(ax, plt_log, True) ax.grid(True, alpha=0.5) # Add labels to plot in the legend diff --git a/fooof/plts/aperiodic.py b/fooof/plts/aperiodic.py index fafc1dd5e..905b95145 100644 --- a/fooof/plts/aperiodic.py +++ b/fooof/plts/aperiodic.py @@ -3,21 +3,23 @@ from itertools import cycle import numpy as np +import matplotlib.pyplot as plt from fooof.sim.gen import gen_freqs, gen_aperiodic from fooof.core.modutils import safe_import, check_dependency from fooof.plts.settings import PLT_FIGSIZES -from fooof.plts.style import check_n_style, style_param_plot -from fooof.plts.utils import check_ax, recursive_plot, check_plot_kwargs +from fooof.plts.style import style_param_plot, style_plot +from fooof.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs plt = safe_import('.pyplot', 'matplotlib') ################################################################################################### ################################################################################################### +@savefig +@style_plot @check_dependency(plt, 'matplotlib') -def plot_aperiodic_params(aps, colors=None, labels=None, - ax=None, plot_style=style_param_plot, **plot_kwargs): +def plot_aperiodic_params(aps, colors=None, labels=None, ax=None, **plot_kwargs): """Plot aperiodic parameters as dots representing offset and exponent value. Parameters @@ -30,17 +32,14 @@ def plot_aperiodic_params(aps, colors=None, labels=None, Label(s) for plotted data, to be added in a legend. ax : matplotlib.Axes, optional Figure axes upon which to plot. - plot_style : callable, optional, default: style_param_plot - A function to call to apply styling & aesthetics to the plot. **plot_kwargs - Keyword arguments to pass into the plot call. + Keyword arguments to pass into the ``style_plot``. """ ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params'])) if isinstance(aps, list): - recursive_plot(aps, plot_aperiodic_params, ax, colors=colors, labels=labels, - plot_style=plot_style, **plot_kwargs) + recursive_plot(aps, plot_aperiodic_params, ax, colors=colors, labels=labels) else: @@ -48,6 +47,7 @@ def plot_aperiodic_params(aps, colors=None, labels=None, xs, ys = aps[:, 0], aps[:, -1] sizes = plot_kwargs.pop('s', 150) + # Create the plot plot_kwargs = check_plot_kwargs(plot_kwargs, {'alpha' : 0.7}) ax.scatter(xs, ys, sizes, c=colors, label=labels, **plot_kwargs) @@ -55,13 +55,15 @@ def plot_aperiodic_params(aps, colors=None, labels=None, ax.set_xlabel('Offset') ax.set_ylabel('Exponent') - check_n_style(plot_style, ax) + style_param_plot(ax) +@savefig +@style_plot @check_dependency(plt, 'matplotlib') def plot_aperiodic_fits(aps, freq_range, control_offset=False, log_freqs=False, colors=None, labels=None, - ax=None, plot_style=style_param_plot, **plot_kwargs): + ax=None, **plot_kwargs): """Plot reconstructions of model aperiodic fits. Parameters @@ -80,10 +82,8 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, Label(s) for plotted data, to be added in a legend. ax : matplotlib.Axes, optional Figure axes upon which to plot. - plot_style : callable, optional, default: style_param_plot - A function to call to apply styling & aesthetics to the plot. **plot_kwargs - Keyword arguments to pass into the plot call. + Keyword arguments to pass into the ``style_plot``. """ ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params'])) @@ -93,11 +93,9 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, if not colors: colors = cycle(plt.rcParams['axes.prop_cycle'].by_key()['color']) - recursive_plot(aps, plot_function=plot_aperiodic_fits, ax=ax, - freq_range=tuple(freq_range), - control_offset=control_offset, - log_freqs=log_freqs, colors=colors, labels=labels, - plot_style=plot_style, **plot_kwargs) + recursive_plot(aps, plot_aperiodic_fits, ax=ax, freq_range=tuple(freq_range), + control_offset=control_offset, log_freqs=log_freqs, colors=colors, + labels=labels, **plot_kwargs) else: freqs = gen_freqs(freq_range, 0.1) @@ -118,8 +116,7 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, # Recreate & plot the aperiodic component from parameters ap_vals = gen_aperiodic(freqs, ap_params) - plot_kwargs = check_plot_kwargs(plot_kwargs, {'alpha' : 0.35, 'linewidth' : 1.25}) - ax.plot(plt_freqs, ap_vals, color=colors, **plot_kwargs) + ax.plot(plt_freqs, ap_vals, color=colors, alpha=0.35, linewidth=1.25) # Collect a running average across components avg_vals = np.nansum(np.vstack([avg_vals, ap_vals]), axis=0) @@ -127,8 +124,7 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, # Plot the average component avg = avg_vals / aps.shape[0] avg_color = 'black' if not colors else colors - ax.plot(plt_freqs, avg, linewidth=plot_kwargs.get('linewidth')*3, - color=avg_color, label=labels) + ax.plot(plt_freqs, avg, linewidth=3.75, color=avg_color, label=labels) # Add axis labels ax.set_xlabel('log(Frequency)' if log_freqs else 'Frequency') @@ -137,5 +133,4 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, # Set plot limit ax.set_xlim(np.log10(freq_range) if log_freqs else freq_range) - # Apply plot style - check_n_style(plot_style, ax) + style_param_plot(ax) diff --git a/fooof/plts/error.py b/fooof/plts/error.py index c870900bc..f7cbfdf7b 100644 --- a/fooof/plts/error.py +++ b/fooof/plts/error.py @@ -5,17 +5,18 @@ from fooof.core.modutils import safe_import, check_dependency from fooof.plts.spectra import plot_spectrum from fooof.plts.settings import PLT_FIGSIZES -from fooof.plts.style import check_n_style, style_spectrum_plot -from fooof.plts.utils import check_ax +from fooof.plts.style import style_spectrum_plot, style_plot +from fooof.plts.utils import check_ax, savefig plt = safe_import('.pyplot', 'matplotlib') ################################################################################################### ################################################################################################### +@savefig +@style_plot @check_dependency(plt, 'matplotlib') -def plot_spectral_error(freqs, error, shade=None, log_freqs=False, - ax=None, plot_style=style_spectrum_plot, **plot_kwargs): +def plot_spectral_error(freqs, error, shade=None, log_freqs=False, ax=None, **plot_kwargs): """Plot frequency by frequency error values. Parameters @@ -31,17 +32,15 @@ def plot_spectral_error(freqs, error, shade=None, log_freqs=False, Whether to plot the frequency axis in log spacing. ax : matplotlib.Axes, optional Figure axes upon which to plot. - plot_style : callable, optional, default: style_spectrum_plot - A function to call to apply styling & aesthetics to the plot. **plot_kwargs - Keyword arguments to be passed to `plot_spectra` or to the plot call. + Keyword arguments to pass into the ``style_plot``. """ ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) plt_freqs = np.log10(freqs) if log_freqs else freqs - plot_spectrum(plt_freqs, error, plot_style=None, ax=ax, linewidth=3, **plot_kwargs) + plot_spectrum(plt_freqs, error, ax=ax, linewidth=3) if np.any(shade): ax.fill_between(plt_freqs, error-shade, error+shade, alpha=0.25) @@ -51,5 +50,5 @@ def plot_spectral_error(freqs, error, shade=None, log_freqs=False, ax.set_ylim([0, ymax]) ax.set_xlim(plt_freqs.min(), plt_freqs.max()) - check_n_style(plot_style, ax, log_freqs, True) + style_spectrum_plot(ax, log_freqs, True) ax.set_ylabel('Absolute Error') diff --git a/fooof/plts/fg.py b/fooof/plts/fg.py index f4d121419..d2f2cc476 100644 --- a/fooof/plts/fg.py +++ b/fooof/plts/fg.py @@ -5,11 +5,12 @@ This file contains plotting functions that take as input a FOOOFGroup object. """ -from fooof.core.io import fname, fpath from fooof.core.errors import NoModelError from fooof.core.modutils import safe_import, check_dependency from fooof.plts.settings import PLT_FIGSIZES from fooof.plts.templates import plot_scatter_1, plot_scatter_2, plot_hist +from fooof.plts.utils import savefig +from fooof.plts.style import style_plot plt = safe_import('.pyplot', 'matplotlib') gridspec = safe_import('.gridspec', 'matplotlib') @@ -17,8 +18,9 @@ ################################################################################################### ################################################################################################### +@savefig @check_dependency(plt, 'matplotlib') -def plot_fg(fg, save_fig=False, file_name=None, file_path=None): +def plot_fg(fg, save_fig=False, file_name=None, file_path=None, **plot_kwargs): """Plot a figure with subplots visualizing the parameters from a FOOOFGroup object. Parameters @@ -44,26 +46,27 @@ def plot_fg(fg, save_fig=False, file_name=None, file_path=None): fig = plt.figure(figsize=PLT_FIGSIZES['group']) gs = gridspec.GridSpec(2, 2, wspace=0.4, hspace=0.25, height_ratios=[1, 1.2]) + # Apply scatter kwargs to all subplots + scatter_kwargs = plot_kwargs + scatter_kwargs['all_axes'] = True + # Aperiodic parameters plot ax0 = plt.subplot(gs[0, 0]) - plot_fg_ap(fg, ax0) + plot_fg_ap(fg, ax0, **scatter_kwargs) # Goodness of fit plot ax1 = plt.subplot(gs[0, 1]) - plot_fg_gf(fg, ax1) + plot_fg_gf(fg, ax1, **scatter_kwargs) # Center frequencies plot ax2 = plt.subplot(gs[1, :]) - plot_fg_peak_cens(fg, ax2) - - if save_fig: - if not file_name: - raise ValueError("Input 'file_name' is required to save out the plot.") - plt.savefig(fpath(file_path, fname(file_name, 'png'))) + plot_fg_peak_cens(fg, ax2, **plot_kwargs) +@savefig +@style_plot @check_dependency(plt, 'matplotlib') -def plot_fg_ap(fg, ax=None): +def plot_fg_ap(fg, ax=None, **plot_kwargs): """Plot aperiodic fit parameters, in a scatter plot. Parameters @@ -72,6 +75,8 @@ def plot_fg_ap(fg, ax=None): Object to plot data from. ax : matplotlib.Axes, optional Figure axes upon which to plot. + **plot_kwargs + Keyword arguments to pass into the ``style_plot``. """ if fg.aperiodic_mode == 'knee': @@ -83,8 +88,10 @@ def plot_fg_ap(fg, ax=None): 'Aperiodic Fit', ax=ax) +@savefig +@style_plot @check_dependency(plt, 'matplotlib') -def plot_fg_gf(fg, ax=None): +def plot_fg_gf(fg, ax=None, **plot_kwargs): """Plot goodness of fit results, in a scatter plot. Parameters @@ -93,14 +100,18 @@ def plot_fg_gf(fg, ax=None): Object to plot data from. ax : matplotlib.Axes, optional Figure axes upon which to plot. + **plot_kwargs + Keyword arguments to pass into the ``style_plot``. """ plot_scatter_2(fg.get_params('error'), 'Error', fg.get_params('r_squared'), 'R^2', 'Goodness of Fit', ax=ax) +@savefig +@style_plot @check_dependency(plt, 'matplotlib') -def plot_fg_peak_cens(fg, ax=None): +def plot_fg_peak_cens(fg, ax=None, **plot_kwargs): """Plot peak center frequencies, in a histogram. Parameters @@ -109,6 +120,8 @@ def plot_fg_peak_cens(fg, ax=None): Object to plot data from. ax : matplotlib.Axes, optional Figure axes upon which to plot. + **plot_kwargs + Keyword arguments to pass into the ``style_plot``. """ plot_hist(fg.get_params('peak_params', 0)[:, 0], 'Center Frequency', diff --git a/fooof/plts/fm.py b/fooof/plts/fm.py index 68be6ea5b..6674848a3 100644 --- a/fooof/plts/fm.py +++ b/fooof/plts/fm.py @@ -7,7 +7,6 @@ import numpy as np -from fooof.core.io import fname, fpath from fooof.core.utils import nearest_ind from fooof.core.modutils import safe_import, check_dependency from fooof.sim.gen import gen_periodic @@ -15,19 +14,20 @@ from fooof.utils.params import compute_fwhm from fooof.plts.spectra import plot_spectrum from fooof.plts.settings import PLT_FIGSIZES, PLT_COLORS -from fooof.plts.utils import check_ax, check_plot_kwargs -from fooof.plts.style import check_n_style, style_spectrum_plot +from fooof.plts.utils import check_ax, check_plot_kwargs, savefig +from fooof.plts.style import style_spectrum_plot, style_plot plt = safe_import('.pyplot', 'matplotlib') ################################################################################################### ################################################################################################### +@savefig +@style_plot @check_dependency(plt, 'matplotlib') def plot_fm(fm, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend=True, - save_fig=False, file_name=None, file_path=None, - ax=None, plot_style=style_spectrum_plot, - data_kwargs=None, model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None): + save_fig=False, file_name=None, file_path=None, ax=None, data_kwargs=None, + model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None, **plot_kwargs): """Plot the power spectrum and model fit results from a FOOOF object. Parameters @@ -51,10 +51,10 @@ def plot_fm(fm, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend= Path to directory to save to. If None, saves to current directory. ax : matplotlib.Axes, optional Figure axes upon which to plot. - plot_style : callable, optional, default: style_spectrum_plot - A function to call to apply styling & aesthetics to the plot. data_kwargs, model_kwargs, aperiodic_kwargs, peak_kwargs : None or dict, optional Keyword arguments to pass into the plot call for each plot element. + **plot_kwargs + Keyword arguments to pass into the ``style_plot``. Notes ----- @@ -70,40 +70,32 @@ def plot_fm(fm, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend= # Plot the data, if available if fm.has_data: - data_kwargs = check_plot_kwargs(data_kwargs, \ - {'color' : PLT_COLORS['data'], 'linewidth' : 2.0, - 'label' : 'Original Spectrum' if add_legend else None}) - plot_spectrum(fm.freqs, fm.power_spectrum, log_freqs, log_powers, - ax=ax, plot_style=None, **data_kwargs) + data_defaults = {'color' : PLT_COLORS['data'], 'linewidth' : 2.0, + 'label' : 'Original Spectrum' if add_legend else None} + data_kwargs = check_plot_kwargs(data_kwargs, data_defaults) + plot_spectrum(fm.freqs, fm.power_spectrum, log_freqs, log_powers, ax=ax, **data_kwargs) # Add the full model fit, and components (if requested) if fm.has_model: - model_kwargs = check_plot_kwargs(model_kwargs, \ - {'color' : PLT_COLORS['model'], 'linewidth' : 3.0, 'alpha' : 0.5, - 'label' : 'Full Model Fit' if add_legend else None}) - plot_spectrum(fm.freqs, fm.fooofed_spectrum_, log_freqs, log_powers, - ax=ax, plot_style=None, **model_kwargs) + model_defaults = {'color' : PLT_COLORS['model'], 'linewidth' : 3.0, 'alpha' : 0.5, + 'label' : 'Full Model Fit' if add_legend else None} + model_kwargs = check_plot_kwargs(model_kwargs, model_defaults) + plot_spectrum(fm.freqs, fm.fooofed_spectrum_, log_freqs, log_powers, ax=ax, **model_kwargs) # Plot the aperiodic component of the model fit if plot_aperiodic: - aperiodic_kwargs = check_plot_kwargs(aperiodic_kwargs, \ - {'color' : PLT_COLORS['aperiodic'], 'linewidth' : 3.0, 'alpha' : 0.5, - 'linestyle' : 'dashed', 'label' : 'Aperiodic Fit' if add_legend else None}) - plot_spectrum(fm.freqs, fm._ap_fit, log_freqs, log_powers, - ax=ax, plot_style=None, **aperiodic_kwargs) + aperiodic_defaults = {'color' : PLT_COLORS['aperiodic'], 'linewidth' : 3.0, + 'alpha' : 0.5, 'linestyle' : 'dashed', + 'label' : 'Aperiodic Fit' if add_legend else None} + aperiodic_kwargs = check_plot_kwargs(aperiodic_kwargs, aperiodic_defaults) + plot_spectrum(fm.freqs, fm._ap_fit, log_freqs, log_powers, ax=ax, **aperiodic_kwargs) # Plot the periodic components of the model fit if plot_peaks: - _add_peaks(fm, plot_peaks, plt_log, ax=ax, peak_kwargs=peak_kwargs) - - # Apply style to plot - check_n_style(plot_style, ax, log_freqs, True) + _add_peaks(fm, plot_peaks, plt_log, ax, peak_kwargs) - # Save out figure, if requested - if save_fig: - if not file_name: - raise ValueError("Input 'file_name' is required to save out the plot.") - plt.savefig(fpath(file_path, fname(file_name, 'png'))) + # Apply default style to plot + style_spectrum_plot(ax, log_freqs, True) def _add_peaks(fm, approach, plt_log, ax, peak_kwargs): @@ -162,18 +154,18 @@ def _add_peaks_shade(fm, plt_log, ax, **plot_kwargs): ax : matplotlib.Axes Figure axes upon which to plot. **plot_kwargs - Keyword arguments to pass into the plot call. + Keyword arguments to pass into the ``fill_between``. """ - kwargs = check_plot_kwargs(plot_kwargs, - {'color' : PLT_COLORS['periodic'], 'alpha' : 0.25}) + defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.25} + plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) for peak in fm.get_params('gaussian_params'): peak_freqs = np.log10(fm.freqs) if plt_log else fm.freqs peak_line = fm._ap_fit + gen_periodic(fm.freqs, peak) - ax.fill_between(peak_freqs, peak_line, fm._ap_fit, **kwargs) + ax.fill_between(peak_freqs, peak_line, fm._ap_fit, **plot_kwargs) def _add_peaks_dot(fm, plt_log, ax, **plot_kwargs): @@ -191,9 +183,8 @@ def _add_peaks_dot(fm, plt_log, ax, **plot_kwargs): Keyword arguments to pass into the plot call. """ - kwargs = check_plot_kwargs(plot_kwargs, - {'color' : PLT_COLORS['periodic'], - 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6}) + defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6} + plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) for peak in fm.get_params('peak_params'): @@ -201,10 +192,10 @@ def _add_peaks_dot(fm, plt_log, ax, **plot_kwargs): freq_point = np.log10(peak[0]) if plt_log else peak[0] # Add the line from the aperiodic fit up the tip of the peak - ax.plot([freq_point, freq_point], [ap_point, ap_point + peak[1]], **kwargs) + ax.plot([freq_point, freq_point], [ap_point, ap_point + peak[1]], **plot_kwargs) # Add an extra dot at the tip of the peak - ax.plot(freq_point, ap_point + peak[1], marker='o', **kwargs) + ax.plot(freq_point, ap_point + peak[1], marker='o', **plot_kwargs) def _add_peaks_outline(fm, plt_log, ax, **plot_kwargs): @@ -222,9 +213,8 @@ def _add_peaks_outline(fm, plt_log, ax, **plot_kwargs): Keyword arguments to pass into the plot call. """ - kwargs = check_plot_kwargs(plot_kwargs, - {'color' : PLT_COLORS['periodic'], - 'alpha' : 0.7, 'lw' : 1.5}) + defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.7, 'lw' : 1.5} + plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) for peak in fm.get_params('gaussian_params'): @@ -237,7 +227,7 @@ def _add_peaks_outline(fm, plt_log, ax, **plot_kwargs): # Plot the peak outline peak_freqs = np.log10(peak_freqs) if plt_log else peak_freqs - ax.plot(peak_freqs, peak_line, **kwargs) + ax.plot(peak_freqs, peak_line, **plot_kwargs) def _add_peaks_line(fm, plt_log, ax, **plot_kwargs): @@ -255,16 +245,15 @@ def _add_peaks_line(fm, plt_log, ax, **plot_kwargs): Keyword arguments to pass into the plot call. """ - kwargs = check_plot_kwargs(plot_kwargs, - {'color' : PLT_COLORS['periodic'], - 'alpha' : 0.7, 'lw' : 1.4, 'ms' : 10}) + defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.7, 'lw' : 1.4, 'ms' : 10} + plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) ylims = ax.get_ylim() for peak in fm.get_params('peak_params'): freq_point = np.log10(peak[0]) if plt_log else peak[0] - ax.plot([freq_point, freq_point], ylims, '-', **kwargs) - ax.plot(freq_point, ylims[1], 'v', **kwargs) + ax.plot([freq_point, freq_point], ylims, '-', **plot_kwargs) + ax.plot(freq_point, ylims[1], 'v', **plot_kwargs) def _add_peaks_width(fm, plt_log, ax, **plot_kwargs): @@ -287,9 +276,8 @@ def _add_peaks_width(fm, plt_log, ax, **plot_kwargs): the peak, though what is literally plotted is the full-width half-max. """ - kwargs = check_plot_kwargs(plot_kwargs, - {'color' : PLT_COLORS['periodic'], - 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6}) + defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6} + plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) for peak in fm.gaussian_params_: @@ -300,7 +288,7 @@ def _add_peaks_width(fm, plt_log, ax, **plot_kwargs): if plt_log: bw_freqs = np.log10(bw_freqs) - ax.plot(bw_freqs, [peak_top-(0.5*peak[1]), peak_top-(0.5*peak[1])], **kwargs) + ax.plot(bw_freqs, [peak_top-(0.5*peak[1]), peak_top-(0.5*peak[1])], **plot_kwargs) # Collect all the possible `add_peak_*` functions together diff --git a/fooof/plts/periodic.py b/fooof/plts/periodic.py index e654bfabc..17e66f1b9 100644 --- a/fooof/plts/periodic.py +++ b/fooof/plts/periodic.py @@ -8,17 +8,18 @@ from fooof.core.funcs import gaussian_function from fooof.core.modutils import safe_import, check_dependency from fooof.plts.settings import PLT_FIGSIZES -from fooof.plts.style import check_n_style, style_param_plot -from fooof.plts.utils import check_ax, recursive_plot, check_plot_kwargs +from fooof.plts.style import style_param_plot, style_plot +from fooof.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs plt = safe_import('.pyplot', 'matplotlib') ################################################################################################### ################################################################################################### +@savefig +@style_plot @check_dependency(plt, 'matplotlib') -def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, - ax=None, plot_style=style_param_plot, **plot_kwargs): +def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, ax=None, **plot_kwargs): """Plot peak parameters as dots representing center frequency, power and bandwidth. Parameters @@ -33,18 +34,15 @@ def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, Label(s) for plotted data, to be added in a legend. ax : matplotlib.Axes, optional Figure axes upon which to plot. - plot_style : callable, optional, default: style_param_plot - A function to call to apply styling & aesthetics to the plot. **plot_kwargs - Keyword arguments to pass into the plot call. + Keyword arguments to pass into the ``style_plot``. """ ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params'])) # If there is a list, use recurse function to loop across arrays of data and plot them if isinstance(peaks, list): - recursive_plot(peaks, plot_peak_params, ax, colors=colors, labels=labels, - plot_style=plot_style, **plot_kwargs) + recursive_plot(peaks, plot_peak_params, ax, colors=colors, labels=labels) # Otherwise, plot the array of data else: @@ -66,11 +64,12 @@ def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, ax.set_xlim(freq_range) ax.set_ylim([0, ax.get_ylim()[1]]) - check_n_style(plot_style, ax) + style_param_plot(ax) -def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, - ax=None, plot_style=style_param_plot, **plot_kwargs): +@savefig +@style_plot +def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, **plot_kwargs): """Plot reconstructions of model peak fits. Parameters @@ -86,8 +85,6 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, Label(s) for plotted data, to be added in a legend. ax : matplotlib.Axes, optional Figure axes upon which to plot. - plot_style : callable, optional, default: style_param_plot - A function to call to apply styling & aesthetics to the plot. **plot_kwargs Keyword arguments to pass into the plot call. """ @@ -101,8 +98,7 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, recursive_plot(peaks, plot_function=plot_peak_fits, ax=ax, freq_range=tuple(freq_range) if freq_range else freq_range, - colors=colors, labels=labels, - plot_style=plot_style, **plot_kwargs) + colors=colors, labels=labels, **plot_kwargs) else: @@ -128,8 +124,7 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, # Create & plot the peak model from parameters peak_vals = gaussian_function(freqs, *peak_params) - plot_kwargs = check_plot_kwargs(plot_kwargs, {'alpha' : 0.35, 'linewidth' : 1.25}) - ax.plot(freqs, peak_vals, color=colors, **plot_kwargs) + ax.plot(freqs, peak_vals, color=colors, alpha=0.35, linewidth=1.25) # Collect a running average average peaks avg_vals = np.nansum(np.vstack([avg_vals, peak_vals]), axis=0) @@ -137,7 +132,7 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, # Plot the average across all components avg = avg_vals / peaks.shape[0] avg_color = 'black' if not colors else colors - ax.plot(freqs, avg, color=avg_color, linewidth=plot_kwargs.get('linewidth')*3, label=labels) + ax.plot(freqs, avg, color=avg_color, linewidth=3.75, label=labels) # Add axis labels ax.set_xlabel('Frequency') @@ -148,4 +143,4 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax.set_ylim([0, ax.get_ylim()[1]]) # Apply plot style - check_n_style(plot_style, ax) + style_param_plot(ax) diff --git a/fooof/plts/settings.py b/fooof/plts/settings.py index 94d525054..4b7f1050e 100644 --- a/fooof/plts/settings.py +++ b/fooof/plts/settings.py @@ -26,3 +26,27 @@ PLT_ALIASES = {'linewidth' : ['lw', 'linewidth'], 'markersize' : ['ms', 'markersize'], 'linestyle' : ['ls', 'linestyle']} + +# Plot style arguments are those that can be defined on an axis object +AXIS_STYLE_ARGS = ['title', 'xlabel', 'ylabel', 'xlim', 'ylim'] + +# Line style arguments are those that can be defined on a line object +LINE_STYLE_ARGS = ['alpha', 'lw', 'linewidth', 'ls', 'linestyle', + 'marker', 'ms', 'markersize'] + +# Collection style arguments are those that can be defined on a collections object +COLLECTION_STYLE_ARGS = ['alpha', 'edgecolor'] + +# Custom style arguments are those that are custom-handled by the plot style function +CUSTOM_STYLE_ARGS = ['title_fontsize', 'label_size', 'tick_labelsize', + 'legend_size', 'legend_loc'] +STYLERS = ['axis_styler', 'line_styler', 'custom_styler'] +STYLE_ARGS = AXIS_STYLE_ARGS + LINE_STYLE_ARGS + CUSTOM_STYLE_ARGS + STYLERS + +## Define default values for plot aesthetics +# These are all custom style arguments +TITLE_FONTSIZE = 20 +LABEL_SIZE = 16 +TICK_LABELSIZE = 16 +LEGEND_SIZE = 12 +LEGEND_LOC = 'best' diff --git a/fooof/plts/spectra.py b/fooof/plts/spectra.py index eaeafdbe7..9c9f4ab53 100644 --- a/fooof/plts/spectra.py +++ b/fooof/plts/spectra.py @@ -5,23 +5,25 @@ This file contains functions for plotting power spectra, that take in data directly. """ -from itertools import repeat +from itertools import repeat, cycle import numpy as np from fooof.core.modutils import safe_import, check_dependency from fooof.plts.settings import PLT_FIGSIZES -from fooof.plts.style import check_n_style, style_spectrum_plot -from fooof.plts.utils import check_ax, add_shades, check_plot_kwargs +from fooof.plts.style import style_spectrum_plot, style_plot +from fooof.plts.utils import check_ax, add_shades, savefig plt = safe_import('.pyplot', 'matplotlib') ################################################################################################### ################################################################################################### +@savefig +@style_plot @check_dependency(plt, 'matplotlib') def plot_spectrum(freqs, power_spectrum, log_freqs=False, log_powers=False, - ax=None, plot_style=style_spectrum_plot, **plot_kwargs): + color=None, label=None, ax=None, **plot_kwargs): """Plot a power spectrum. Parameters @@ -34,12 +36,14 @@ def plot_spectrum(freqs, power_spectrum, log_freqs=False, log_powers=False, Whether to plot the frequency axis in log spacing. log_powers : bool, optional, default: False Whether to plot the power axis in log spacing. + label : str, optional, default: None + Legend label for the spectrum. + color : str, optional, default: None + Line color of the spectrum. ax : matplotlib.Axes, optional Figure axis upon which to plot. - plot_style : callable, optional, default: style_spectrum_plot - A function to call to apply styling & aesthetics to the plot. **plot_kwargs - Keyword arguments to be passed to the plot call. + Keyword arguments to pass into the ``style_plot``. """ ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) @@ -49,15 +53,16 @@ def plot_spectrum(freqs, power_spectrum, log_freqs=False, log_powers=False, plt_powers = np.log10(power_spectrum) if log_powers else power_spectrum # Create the plot - plot_kwargs = check_plot_kwargs(plot_kwargs, {'linewidth' : 2.0}) - ax.plot(plt_freqs, plt_powers, **plot_kwargs) + ax.plot(plt_freqs, plt_powers, linewidth=2.0, color=color, label=label) - check_n_style(plot_style, ax, log_freqs, log_powers) + style_spectrum_plot(ax, log_freqs, log_powers) +@savefig +@style_plot @check_dependency(plt, 'matplotlib') -def plot_spectra(freqs, power_spectra, log_freqs=False, log_powers=False, labels=None, - ax=None, plot_style=style_spectrum_plot, **plot_kwargs): +def plot_spectra(freqs, power_spectra, log_freqs=False, log_powers=False, + colors=None, labels=None, ax=None, **plot_kwargs): """Plot multiple power spectra on the same plot. Parameters @@ -70,32 +75,35 @@ def plot_spectra(freqs, power_spectra, log_freqs=False, log_powers=False, labels Whether to plot the frequency axis in log spacing. log_powers : bool, optional, default: False Whether to plot the power axis in log spacing. - labels : list of str, optional - Legend labels, for each power spectrum. + labels : list of str, optional, default: None + Legend labels for the spectra. + colors : list of str, optional, default: None + Line colors of the spectra. ax : matplotlib.Axes, optional Figure axes upon which to plot. - plot_style : callable, optional, default: style_spectrum_plot - A function to call to apply styling & aesthetics to the plot. **plot_kwargs - Keyword arguments to be passed to the plot call. + Keyword arguments to pass into the ``style_plot``. """ ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) # Make inputs iterable if need to be passed multiple times to plot each spectrum freqs = repeat(freqs) if isinstance(freqs, np.ndarray) and freqs.ndim == 1 else freqs - labels = repeat(labels) if not isinstance(labels, list) else labels - for freq, power_spectrum, label in zip(freqs, power_spectra, labels): - plot_spectrum(freq, power_spectrum, log_freqs, log_powers, label=label, - plot_style=None, ax=ax, **plot_kwargs) + colors = repeat(colors) if not isinstance(colors, list) else cycle(colors) + labels = repeat(labels) if not isinstance(labels, list) else cycle(labels) - check_n_style(plot_style, ax, log_freqs, log_powers) + for freq, power_spectrum, color, label in zip(freqs, power_spectra, colors, labels): + plot_spectrum(freq, power_spectrum, log_freqs, log_powers, + color=color, label=label, ax=ax) + style_spectrum_plot(ax, log_freqs, log_powers) + +@savefig @check_dependency(plt, 'matplotlib') -def plot_spectrum_shading(freqs, power_spectrum, shades, shade_colors='r', add_center=False, - ax=None, plot_style=style_spectrum_plot, **plot_kwargs): +def plot_spectrum_shading(freqs, power_spectrum, shades, shade_colors='r', + add_center=False, ax=None, **plot_kwargs): """Plot a power spectrum with a shaded frequency region (or regions). Parameters @@ -112,26 +120,24 @@ def plot_spectrum_shading(freqs, power_spectrum, shades, shade_colors='r', add_c Whether to add a line at the center point of the shaded regions. ax : matplotlib.Axes, optional Figure axes upon which to plot. - plot_style : callable, optional, default: style_spectrum_plot - A function to call to apply styling & aesthetics to the plot. **plot_kwargs - Keyword arguments to be passed to the plot call. + Keyword arguments to pass into :func:`~.plot_spectrum`. """ ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) - plot_spectrum(freqs, power_spectrum, plot_style=None, ax=ax, **plot_kwargs) + plot_spectrum(freqs, power_spectrum, ax=ax, **plot_kwargs) add_shades(ax, shades, shade_colors, add_center, plot_kwargs.get('log_freqs', False)) - check_n_style(plot_style, ax, - plot_kwargs.get('log_freqs', False), - plot_kwargs.get('log_powers', False)) + style_spectrum_plot(ax, plot_kwargs.get('log_freqs', False), + plot_kwargs.get('log_powers', False)) +@savefig @check_dependency(plt, 'matplotlib') -def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r', add_center=False, - ax=None, plot_style=style_spectrum_plot, **plot_kwargs): +def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r', + add_center=False, ax=None, **plot_kwargs): """Plot a group of power spectra with a shaded frequency region (or regions). Parameters @@ -148,10 +154,8 @@ def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r', add_cen Whether to add a line at the center point of the shaded regions. ax : matplotlib.Axes, optional Figure axes upon which to plot. - plot_style : callable, optional, default: style_spectrum_plot - A function to call to apply styling & aesthetics to the plot. **plot_kwargs - Keyword arguments to be passed to `plot_spectra` or to the plot call. + Keyword arguments to pass into :func:`~.plot_spectra`. Notes ----- @@ -162,10 +166,9 @@ def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r', add_cen ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) - plot_spectra(freqs, power_spectra, ax=ax, plot_style=None, **plot_kwargs) + plot_spectra(freqs, power_spectra, ax=ax, **plot_kwargs) add_shades(ax, shades, shade_colors, add_center, plot_kwargs.get('log_freqs', False)) - check_n_style(plot_style, ax, - plot_kwargs.get('log_freqs', False), - plot_kwargs.get('log_powers', False)) + style_spectrum_plot(ax, plot_kwargs.get('log_freqs', False), + plot_kwargs.get('log_powers', False)) diff --git a/fooof/plts/style.py b/fooof/plts/style.py index f9dbcfe80..dc72a1142 100644 --- a/fooof/plts/style.py +++ b/fooof/plts/style.py @@ -1,22 +1,16 @@ """Style and aesthetics definitions for plots.""" -################################################################################################### -################################################################################################### +from itertools import cycle +from functools import wraps -def check_n_style(style_func, *args): - """"Check if a style function has been passed, and apply it to a plot if so. +import matplotlib.pyplot as plt - Parameters - ---------- - style_func : callable or None - Function to apply styling to a plot axis. - *args - Inputs to the style plot. - """ - - if style_func: - style_func(*args) +from fooof.plts.settings import (AXIS_STYLE_ARGS, LINE_STYLE_ARGS, COLLECTION_STYLE_ARGS, + CUSTOM_STYLE_ARGS, STYLE_ARGS, LABEL_SIZE, LEGEND_SIZE, + LEGEND_LOC, TICK_LABELSIZE, TITLE_FONTSIZE) +################################################################################################### +################################################################################################### def style_spectrum_plot(ax, log_freqs, log_powers): """Apply style and aesthetics to a power spectrum plot. @@ -75,3 +69,184 @@ def style_param_plot(ax): legend = ax.legend(prop={'size': 16}) for handle in legend.legendHandles: handle._sizes = [100] + + +def apply_axis_style(ax, style_args=AXIS_STYLE_ARGS, **kwargs): + """Apply axis plot style. + + Parameters + ---------- + ax : matplotlib.Axes + Figure axes to apply style to. + style_args : list of str + A list of arguments to be sub-selected from `kwargs` and applied as axis styling. + **kwargs + Keyword arguments that define plot style to apply. + """ + + # Apply any provided axis style arguments + plot_kwargs = {key : val for key, val in kwargs.items() if key in style_args} + ax.set(**plot_kwargs) + + +def apply_line_style(ax, style_args=LINE_STYLE_ARGS, **kwargs): + """Apply line plot style. + + Parameters + ---------- + ax : matplotlib.Axes + Figure axes to apply style to. + style_args : list of str + A list of arguments to be sub-selected from `kwargs` and applied as line styling. + **kwargs + Keyword arguments that define line style to apply. + """ + + # Check how many lines are from the current plot call, to apply style to + # If available, this indicates the apply styling to the last 'n' lines + n_lines_apply = kwargs.pop('n_lines_apply', 0) + + # Get the line related styling arguments from the keyword arguments + line_kwargs = {key : val for key, val in kwargs.items() if key in style_args} + + # Apply any provided line style arguments + for style, value in line_kwargs.items(): + + # Values should be either a single value, for all lines, or a list, of a value per line + # This line checks type, and makes a cycle-able / loop-able object out of the values + values = cycle([value] if isinstance(value, (int, float, str)) else value) + for line in ax.lines[-n_lines_apply:]: + line.set(**{style : next(values)}) + + +def apply_collection_style(ax, style_args=COLLECTION_STYLE_ARGS, **kwargs): + """Apply collection plot style. + + Parameters + ---------- + ax : matplotlib.Axes + Figure axes to apply style to. + style_args : list of str + A list of arguments to be sub-selected from `kwargs` and applied as collection styling. + **kwargs + Keyword arguments that define collection style to apply. + """ + + # Get the collection related styling arguments from the keyword arguments + collection_kwargs = {key : val for key, val in kwargs.items() if key in style_args} + + # Apply any provided collection style arguments + for collection in ax.collections: + collection.set(**collection_kwargs) + + +def apply_custom_style(ax, **kwargs): + """Apply custom plot style. + + Parameters + ---------- + ax : matplotlib.Axes + Figure axes to apply style to. + **kwargs + Keyword arguments that define custom style to apply. + """ + + # If a title was provided, update the size + if ax.get_title(): + ax.title.set_size(kwargs.pop('title_fontsize', TITLE_FONTSIZE)) + + # Settings for the axis labels + label_size = kwargs.pop('label_size', LABEL_SIZE) + ax.xaxis.label.set_size(label_size) + ax.yaxis.label.set_size(label_size) + + # Settings for the axis ticks + ax.tick_params(axis='both', which='major', + labelsize=kwargs.pop('tick_labelsize', TICK_LABELSIZE)) + + # If labels were provided, add a legend + if ax.get_legend_handles_labels()[0]: + ax.legend(prop={'size': kwargs.pop('legend_size', LEGEND_SIZE)}, + loc=kwargs.pop('legend_loc', LEGEND_LOC)) + + plt.tight_layout() + + +def apply_style(ax, axis_styler=apply_axis_style, line_styler=apply_line_style, + collection_styler=apply_collection_style, custom_styler=apply_custom_style, + **kwargs): + """Apply plot style to a figure axis. + + Parameters + ---------- + ax : matplotlib.Axes + Figure axes to apply style to. + axis_styler, line_styler, collection_style, custom_styler : callable, optional + Functions to apply style to aspects of the plot. + **kwargs + Keyword arguments that define style to apply. + + Notes + ----- + This function wraps sub-functions which apply style to different plot elements. + Each of these sub-functions can be replaced by passing in replacement callables. + """ + + axis_styler(ax, **kwargs) + line_styler(ax, **kwargs) + collection_styler(ax, **kwargs) + custom_styler(ax, **kwargs) + + +def style_plot(func, *args, **kwargs): + """Decorator function to apply a plot style function, after plot generation. + + Parameters + ---------- + func : callable + The plotting function for creating a plot. + *args, **kwargs + Arguments & keyword arguments. + These should include any arguments for the plot, and those for applying plot style. + + Notes + ----- + This decorator works by: + + - catching all inputs that relate to plot style + - creating a plot, using the passed in plotting function & passing in all non-style arguments + - passing the style related arguments into a `apply_style` function which applies plot styling + + By default, this function applies styling with the `apply_style` function. Custom + functions for applying style can be passed in using `apply_style` as a keyword argument. + + The `apply_style` function calls sub-functions for applying style different plot elements, + and these sub-functions can be overridden by passing in alternatives for `axis_styler`, + `line_styler`, and `custom_styler`. + """ + + @wraps(func) + def decorated(*args, **kwargs): + + # Grab a custom style function, if provided, and grab any provided style arguments + style_func = kwargs.pop('plot_style', apply_style) + style_args = kwargs.pop('style_args', STYLE_ARGS) + style_kwargs = {key : kwargs.pop(key) for key in style_args if key in kwargs} + + # Check how many lines are already on the plot, if it exists already + n_lines_pre = len(kwargs['ax'].lines) if 'ax' in kwargs and kwargs['ax'] is not None else 0 + + # Create the plot + func(*args, **kwargs) + + # Get plot axis, if a specific one was provided, or if not, grab the current axis + cur_ax = kwargs['ax'] if 'ax' in kwargs and kwargs['ax'] is not None else plt.gca() + + # Check how many lines were added to the plot, and make info available to plot styling + n_lines_apply = len(cur_ax.lines) - n_lines_pre + style_kwargs['n_lines_apply'] = n_lines_apply + + # Apply the styling function + style_func(cur_ax, **style_kwargs) + + return decorated diff --git a/fooof/plts/utils.py b/fooof/plts/utils.py index ef5b53901..0a970ee91 100644 --- a/fooof/plts/utils.py +++ b/fooof/plts/utils.py @@ -8,9 +8,11 @@ from itertools import repeat from collections.abc import Iterator +from functools import wraps import numpy as np +from fooof.core.io import fname, fpath from fooof.core.modutils import safe_import from fooof.core.utils import resolve_aliases from fooof.plts.settings import PLT_ALPHA_LEVELS, PLT_ALIASES @@ -171,3 +173,23 @@ def check_plot_kwargs(plot_kwargs, defaults): plot_kwargs[key] = value return plot_kwargs + + +def savefig(func): + """Decorator function to save out figures.""" + + @wraps(func) + def decorated(*args, **kwargs): + + save_fig = kwargs.pop('save_fig', False) + file_name = kwargs.pop('file_name', None) + file_path = kwargs.pop('file_path', None) + + func(*args, **kwargs) + + if save_fig: + if not file_name: + raise ValueError("Input 'file_name' is required to save out the plot.") + plt.savefig(fpath(file_path, fname(file_name, 'png'))) + + return decorated diff --git a/fooof/tests/conftest.py b/fooof/tests/conftest.py index 65a8b00a9..b943456ce 100644 --- a/fooof/tests/conftest.py +++ b/fooof/tests/conftest.py @@ -9,7 +9,8 @@ from fooof.core.modutils import safe_import from fooof.tests.tutils import get_tfm, get_tfg, get_tbands -from fooof.tests.settings import BASE_TEST_FILE_PATH, TEST_DATA_PATH, TEST_REPORTS_PATH +from fooof.tests.settings import (BASE_TEST_FILE_PATH, TEST_DATA_PATH, + TEST_REPORTS_PATH, TEST_PLOTS_PATH) plt = safe_import('.pyplot', 'matplotlib') @@ -33,6 +34,7 @@ def check_dir(): os.mkdir(BASE_TEST_FILE_PATH) os.mkdir(TEST_DATA_PATH) os.mkdir(TEST_REPORTS_PATH) + os.mkdir(TEST_PLOTS_PATH) @pytest.fixture(scope='session') def tfm(): diff --git a/fooof/tests/plts/test_annotate.py b/fooof/tests/plts/test_annotate.py index c096612cc..84f3848df 100644 --- a/fooof/tests/plts/test_annotate.py +++ b/fooof/tests/plts/test_annotate.py @@ -3,6 +3,7 @@ import numpy as np from fooof.tests.tutils import plot_test +from fooof.tests.settings import TEST_PLOTS_PATH from fooof.plts.annotate import * @@ -12,11 +13,13 @@ @plot_test def test_plot_annotated_peak_search(tfm, skip_if_no_mpl): - plot_annotated_peak_search(tfm) + plot_annotated_peak_search(tfm, save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_annotated_peak_search.png') @plot_test def test_plot_annotated_model(tfm, skip_if_no_mpl): # Make sure model has been fit & then plot annotated model tfm.fit() - plot_annotated_model(tfm) + plot_annotated_model(tfm, save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_annotated_model.png') diff --git a/fooof/tests/plts/test_aperiodic.py b/fooof/tests/plts/test_aperiodic.py index 167907846..477b22053 100644 --- a/fooof/tests/plts/test_aperiodic.py +++ b/fooof/tests/plts/test_aperiodic.py @@ -3,6 +3,7 @@ import numpy as np from fooof.tests.tutils import plot_test +from fooof.tests.settings import TEST_PLOTS_PATH from fooof.plts.aperiodic import * @@ -21,7 +22,8 @@ def test_plot_aperiodic_params(skip_if_no_mpl): # Test for 'knee' mode: offset, knee exponent aps = np.array([[1, 100, 1], [0.5, 150, 0.5], [2, 200, 2]]) - plot_aperiodic_params(aps) + plot_aperiodic_params(aps, save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_aperiodic_params.png') @plot_test def test_plot_aperiodic_fits(skip_if_no_mpl): @@ -36,4 +38,5 @@ def test_plot_aperiodic_fits(skip_if_no_mpl): # Test for 'knee' mode: offset, knee exponent aps = np.array([[1, 100, 1], [0.5, 150, 0.5], [2, 200, 2]]) - plot_aperiodic_fits(aps, [1, 50]) + plot_aperiodic_fits(aps, [1, 50], save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_aperiodic_fits.png') diff --git a/fooof/tests/plts/test_error.py b/fooof/tests/plts/test_error.py index 2bffbc7a2..3e8b817bd 100644 --- a/fooof/tests/plts/test_error.py +++ b/fooof/tests/plts/test_error.py @@ -3,6 +3,7 @@ import numpy as np from fooof.tests.tutils import plot_test +from fooof.tests.settings import TEST_PLOTS_PATH from fooof.plts.error import * @@ -15,4 +16,5 @@ def test_plot_spectral_error(skip_if_no_mpl): fs = np.arange(3, 41, 1) errs = np.ones(len(fs)) - plot_spectral_error(fs, errs) + plot_spectral_error(fs, errs, save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_spectral_error.png') diff --git a/fooof/tests/plts/test_fg.py b/fooof/tests/plts/test_fg.py index 8841c9fba..24103a916 100644 --- a/fooof/tests/plts/test_fg.py +++ b/fooof/tests/plts/test_fg.py @@ -6,6 +6,7 @@ from fooof.core.errors import NoModelError from fooof.tests.tutils import plot_test +from fooof.tests.settings import TEST_PLOTS_PATH from fooof.plts.fg import * @@ -15,7 +16,8 @@ @plot_test def test_plot_fg(tfg, skip_if_no_mpl): - plot_fg(tfg) + plot_fg(tfg, save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_fg.png') # Test error if no data available to plot tfg = FOOOFGroup() @@ -25,14 +27,17 @@ def test_plot_fg(tfg, skip_if_no_mpl): @plot_test def test_plot_fg_ap(tfg, skip_if_no_mpl): - plot_fg_ap(tfg) + plot_fg_ap(tfg, save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_fg_ap.png') @plot_test def test_plot_fg_gf(tfg, skip_if_no_mpl): - plot_fg_gf(tfg) + plot_fg_gf(tfg, save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_fg_gf.png') @plot_test def test_plot_fg_peak_cens(tfg, skip_if_no_mpl): - plot_fg_peak_cens(tfg) + plot_fg_peak_cens(tfg, save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_fg_peak_cens.png') diff --git a/fooof/tests/plts/test_fm.py b/fooof/tests/plts/test_fm.py index 7fce65ad5..6d7a0f02f 100644 --- a/fooof/tests/plts/test_fm.py +++ b/fooof/tests/plts/test_fm.py @@ -1,6 +1,7 @@ """Tests for fooof.plts.fm.""" from fooof.tests.tutils import plot_test +from fooof.tests.settings import TEST_PLOTS_PATH from fooof.plts.fm import * @@ -13,7 +14,8 @@ def test_plot_fm(tfm, skip_if_no_mpl): # Make sure model has been fit tfm.fit() - plot_fm(tfm) + plot_fm(tfm, save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_fm.png') @plot_test def test_plot_fm_add_peaks(tfm, skip_if_no_mpl): @@ -22,9 +24,7 @@ def test_plot_fm_add_peaks(tfm, skip_if_no_mpl): tfm.fit() # Test run each of the add peak approaches - for add_peak in ['shade', 'dot', 'outline', 'line']: - plot_fm(tfm, plot_peaks=add_peak) - - # Test run some combinations - for add_peak in ['shade-dot', 'outline-line']: - plot_fm(tfm, plot_peaks=add_peak) + for add_peak in ['shade', 'dot', 'outline', 'line', 'shade-dot', 'outline-line']: + file_name = 'test_plot_fm_add_peaks_' + add_peak + '.png' + plot_fm(tfm, plot_peaks=add_peak, save_fig=True, + file_path=TEST_PLOTS_PATH, file_name=file_name) diff --git a/fooof/tests/plts/test_periodic.py b/fooof/tests/plts/test_periodic.py index 647c967bd..83e77daf9 100644 --- a/fooof/tests/plts/test_periodic.py +++ b/fooof/tests/plts/test_periodic.py @@ -3,6 +3,7 @@ import numpy as np from fooof.tests.tutils import plot_test +from fooof.tests.settings import TEST_PLOTS_PATH from fooof.plts.periodic import * @@ -18,7 +19,8 @@ def test_plot_peak_params(skip_if_no_mpl): plot_peak_params(peaks) # Test with multiple set of params - plot_peak_params([peaks, peaks]) + plot_peak_params([peaks, peaks], save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_peak_params.png') @plot_test def test_plot_peak_fits(skip_if_no_mpl): @@ -29,4 +31,5 @@ def test_plot_peak_fits(skip_if_no_mpl): plot_peak_fits(peaks) # Test with multiple set of params - plot_peak_fits([peaks, peaks]) + plot_peak_fits([peaks, peaks], save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_peak_fits.png') diff --git a/fooof/tests/plts/test_spectra.py b/fooof/tests/plts/test_spectra.py index 9abb95b2a..0b85e2d9e 100644 --- a/fooof/tests/plts/test_spectra.py +++ b/fooof/tests/plts/test_spectra.py @@ -3,6 +3,7 @@ import numpy as np from fooof.tests.tutils import plot_test +from fooof.tests.settings import TEST_PLOTS_PATH from fooof.plts.spectra import * @@ -15,36 +16,46 @@ def test_plot_spectrum(tfm, skip_if_no_mpl): plot_spectrum(tfm.freqs, tfm.power_spectrum) # Test with logging both axes - plot_spectrum(tfm.freqs, tfm.power_spectrum, True, True) + plot_spectrum(tfm.freqs, tfm.power_spectrum, True, True, save_fig=True, + file_path=TEST_PLOTS_PATH, file_name='test_plot_spectrum.png') @plot_test def test_plot_spectra(tfg, skip_if_no_mpl): # Test with 1d inputs - 1d freq array and list of 1d power spectra - plot_spectra(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]]) + plot_spectra(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]], + save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_1d.png') # Test with multiple freq inputs - list of 1d freq array and list of 1d power spectra - plot_spectra([tfg.freqs, tfg.freqs], [tfg.power_spectra[0, :], tfg.power_spectra[1, :]]) + plot_spectra([tfg.freqs, tfg.freqs], [tfg.power_spectra[0, :], tfg.power_spectra[1, :]], + save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_spectra_list_of_1d.png') # Test with 2d array inputs plot_spectra(np.vstack([tfg.freqs, tfg.freqs]), - np.vstack([tfg.power_spectra[0, :], tfg.power_spectra[1, :]])) + np.vstack([tfg.power_spectra[0, :], tfg.power_spectra[1, :]]), + save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_2d.png') # Test with labels - plot_spectra(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]], labels=['A', 'B']) + plot_spectra(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]], labels=['A', 'B'], + save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_labels.png') @plot_test def test_plot_spectrum_shading(tfm, skip_if_no_mpl): - plot_spectrum_shading(tfm.freqs, tfm.power_spectrum, shades=[8, 12], add_center=True) + plot_spectrum_shading(tfm.freqs, tfm.power_spectrum, shades=[8, 12], add_center=True, + save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_spectrum_shading.png') @plot_test def test_plot_spectra_shading(tfg, skip_if_no_mpl): plot_spectra_shading(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]], - shades=[8, 12], add_center=True) + shades=[8, 12], add_center=True, save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_spectra_shading.png') # Test with **kwargs that pass into plot_spectra plot_spectra_shading(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]], shades=[8, 12], add_center=True, log_freqs=True, log_powers=True, - labels=['A', 'B']) + labels=['A', 'B'], save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_spectra_shading_kwargs.png') diff --git a/fooof/tests/plts/test_styles.py b/fooof/tests/plts/test_styles.py index f1b2f09af..72854ff97 100644 --- a/fooof/tests/plts/test_styles.py +++ b/fooof/tests/plts/test_styles.py @@ -1,20 +1,11 @@ """Tests for fooof.plts.styles.""" +from fooof.tests.tutils import plot_test from fooof.plts.style import * ################################################################################################### ################################################################################################### -def test_check_n_style(skip_if_no_mpl): - - # Check can pass None and do nothing - check_n_style(None) - - # Check can pass a callable - def checker(*args): - return True - check_n_style(checker) - def test_style_spectrum_plot(skip_if_no_mpl): # Create a dummy plot and style it @@ -26,3 +17,88 @@ def test_style_spectrum_plot(skip_if_no_mpl): # Check that axis labels are added - use as proxy that it ran correctly assert ax.xaxis.get_label().get_text() assert ax.yaxis.get_label().get_text() + + +def test_apply_axis_style(): + + _, ax = plt.subplots() + + title = 'Ploty McPlotface' + xlim = (1.0, 10.0) + ylabel = 'Line Value' + + apply_axis_style(ax, title=title, xlim=xlim, ylabel=ylabel) + + assert ax.get_title() == title + assert ax.get_xlim() == xlim + assert ax.get_ylabel() == ylabel + + +def test_apply_line_style(): + + # Check applying style to one line + _, ax = plt.subplots() + ax.plot([1, 2], [3, 4]) + + lw = 4 + apply_line_style(ax, lw=lw) + + assert ax.get_lines()[0].get_lw() == lw + + # Check applying style across multiple lines + _, ax = plt.subplots() + ax.plot([1, 2], [[3, 4], [5, 6]]) + + alphas = [0.5, 0.75] + apply_line_style(ax, alpha=alphas) + + for line, alpha in zip(ax.get_lines(), alphas): + assert line.get_alpha() == alpha + + +def test_apply_custom_style(): + + _, ax = plt.subplots() + ax.set_title('placeholder') + + # Test simple application of custom plot style + apply_custom_style(ax) + assert ax.title.get_size() == TITLE_FONTSIZE + + # Test adding input parameters to custom plot style + new_title_fontsize = 15.0 + apply_custom_style(ax, title_fontsize=new_title_fontsize) + assert ax.title.get_size() == new_title_fontsize + + +def test_apply_style(): + + _, ax = plt.subplots() + + def my_custom_styler(ax, **kwargs): + ax.set_title('DATA!') + + # Apply plot style using all defaults + apply_style(ax) + + # Apply plot style passing in a styler + apply_style(ax, custom_styler=my_custom_styler) + + +@plot_test +def test_style_plot(): + + @style_plot + def example_plot(): + plt.plot([1, 2], [3, 4]) + + def my_plot_style(ax, **kwargs): + ax.set_title('Custom!') + + # Test with applying default custom styling + lw = 5 + title = 'Science.' + example_plot(title=title, lw=lw) + + # Test with passing in own plot_style function + example_plot(plot_style=my_plot_style) diff --git a/fooof/tests/plts/test_utils.py b/fooof/tests/plts/test_utils.py index e7b2c5abb..0c4dce7f2 100644 --- a/fooof/tests/plts/test_utils.py +++ b/fooof/tests/plts/test_utils.py @@ -1,5 +1,8 @@ """Tests for fooof.plts.utils.""" +import os +import tempfile + from fooof.tests.tutils import plot_test from fooof.core.modutils import safe_import @@ -69,3 +72,13 @@ def test_check_plot_kwargs(skip_if_no_mpl): assert len(plot_kwargs) == 2 assert plot_kwargs['alpha'] == 0.5 assert plot_kwargs['linewidth'] == 2 + +def test_savefig(): + + @savefig + def example_plot(): + plt.plot([1, 2], [3, 4]) + + with tempfile.NamedTemporaryFile(mode='w+') as file: + example_plot(save_fig=True, file_name=file.name) + assert os.path.exists(file.name) diff --git a/fooof/tests/settings.py b/fooof/tests/settings.py index 9beae1afa..856c532f6 100644 --- a/fooof/tests/settings.py +++ b/fooof/tests/settings.py @@ -10,3 +10,4 @@ BASE_TEST_FILE_PATH = pkg.resource_filename(__name__, 'test_files') TEST_DATA_PATH = os.path.join(BASE_TEST_FILE_PATH, 'data') TEST_REPORTS_PATH = os.path.join(BASE_TEST_FILE_PATH, 'reports') +TEST_PLOTS_PATH = os.path.join(BASE_TEST_FILE_PATH, 'plots')