diff --git a/hnn_core/gui/_viz_manager.py b/hnn_core/gui/_viz_manager.py index 41eff684a..1ddfcdd3d 100644 --- a/hnn_core/gui/_viz_manager.py +++ b/hnn_core/gui/_viz_manager.py @@ -526,10 +526,12 @@ def _clear_axis(b, widgets, data, fig_idx, fig, ax, widgets_plot_type, existing_plots, add_plot_button): ax.clear() - # Remove "plot_spikes_hist"'s inverted second axes object, if exists - for axis in fig.axes: - if axis._label == "Inverted second axis": - axis.remove() + # Remove "plot_spikes_hist"'s inverted second axes object, if exists, and + # if the axis you are clearing is the spike histogram + if ax._label == "Spike histogram": + for axis in fig.axes: + if axis._label == "Inverted spike histogram": + axis.remove() # remove attached colorbar if exists if hasattr(fig, f'_cbar-ax-{id(ax)}'): diff --git a/hnn_core/tests/test_gui.py b/hnn_core/tests/test_gui.py index d23fbb440..123b72d81 100644 --- a/hnn_core/tests/test_gui.py +++ b/hnn_core/tests/test_gui.py @@ -666,6 +666,11 @@ def test_gui_visualization(setup_gui): for viz_type in _plot_types: gui._simulate_viz_action("edit_figure", figname, axname, 'default', viz_type, {}, 'clear') + # Check that extra axes have been successfully removed + assert len(gui.viz_manager.figs[figid].axes) == 1 + # Check if data on the axes has been successfully cleared + assert not gui.viz_manager.figs[figid].axes[0].has_data() + gui._simulate_viz_action("edit_figure", figname, axname, 'default', viz_type, {}, 'plot') # Check if data is plotted on the axes diff --git a/hnn_core/viz.py b/hnn_core/viz.py index c789d9d79..19eb7e9a4 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -538,7 +538,6 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None, ax1.hist(plot_data, bins, label=spike_label, color=hist_color, **kwargs_hist) # Need to add label for easy removal later - ax1.set_label("Inverted second axis") # Set the y-limits based on the maximum across both axes if ax1 is not None: @@ -549,6 +548,7 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None, ax.set_ylim(0, y_max) ax1.set_ylim(0, y_max) ax1.invert_yaxis() + ax1.set_label("Inverted spike histogram") if len(cell_response.times) > 0: ax.set_xlim(left=0, right=cell_response.times[-1]) @@ -556,6 +556,7 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None, ax.set_xlim(left=0) ax.set_ylabel("Counts") + ax.set_label("Spike histogram") if ax1 is not None: # Combine legends