diff --git a/neurodsp/plts/utils.py b/neurodsp/plts/utils.py index 2f613d4d..5de0bcf8 100644 --- a/neurodsp/plts/utils.py +++ b/neurodsp/plts/utils.py @@ -42,10 +42,20 @@ def decorated(*args, **kwargs): # Defaults to saving when file name given (since bool(str)->True; bool(None)->False) save_fig = kwargs.pop('save_fig', bool(file_name)) + # Check any collect any other plot keywords + save_kwargs = kwargs.pop('save_kwargs', {}) + save_kwargs.setdefault('bbox_inches', 'tight') + + # Check and collect whether to close the plot + close = kwargs.pop('close', None) + func(*args, **kwargs) if save_fig: full_path = pjoin(file_path, file_name) if file_path else file_name - plt.savefig(full_path) + plt.savefig(full_path, **save_kwargs) + + if close: + plt.close() return decorated diff --git a/neurodsp/tests/plts/test_utils.py b/neurodsp/tests/plts/test_utils.py index 0fb4d80d..733c960e 100644 --- a/neurodsp/tests/plts/test_utils.py +++ b/neurodsp/tests/plts/test_utils.py @@ -39,6 +39,11 @@ def example_plot(): example_plot(save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_savefig2.pdf') assert os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig2.pdf')) + # Test giving additional save kwargs + example_plot(file_path=TEST_PLOTS_PATH, file_name='test_savefig3.pdf', + save_kwargs={'facecolor' : 'red'}) + assert os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig3.pdf')) + # Test does not save when `save_fig` set to False - example_plot(save_fig=False, file_path=TEST_PLOTS_PATH, file_name='test_savefig3.pdf') - assert not os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig3.pdf')) + example_plot(save_fig=False, file_path=TEST_PLOTS_PATH, file_name='test_savefig_nope.pdf') + assert not os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig_nope.pdf'))