From 4e7cf8bf18a28c7b87fd9f59a2656b9e0a700dd8 Mon Sep 17 00:00:00 2001 From: David Stansby Date: Fri, 5 May 2023 11:15:36 +0100 Subject: [PATCH] Simplify logic for adding a single axes --- src/napari_matplotlib/base.py | 45 +++++++++++++++++++----------- src/napari_matplotlib/histogram.py | 3 +- src/napari_matplotlib/scatter.py | 3 +- src/napari_matplotlib/slice.py | 2 +- 4 files changed, 32 insertions(+), 21 deletions(-) diff --git a/src/napari_matplotlib/base.py b/src/napari_matplotlib/base.py index 9267d75c..a980a834 100644 --- a/src/napari_matplotlib/base.py +++ b/src/napari_matplotlib/base.py @@ -3,10 +3,12 @@ from typing import List, Tuple import napari +from matplotlib.axes import Axes from matplotlib.backends.backend_qt5agg import ( FigureCanvas, NavigationToolbar2QT, ) +from matplotlib.figure import Figure from qtpy.QtGui import QIcon from qtpy.QtWidgets import QVBoxLayout, QWidget @@ -23,6 +25,8 @@ class NapariMPLWidget(QWidget): Base Matplotlib canvas. Widget that can be embedded as a napari widget. This creates a single FigureCanvas, which contains a single Figure. + It is not responsible for creating any Axes, because different widgets + may want to implement different subplot layouts. This class also handles callbacks to automatically update figures when the layer selection or z-step is changed in the napari viewer. To take @@ -33,8 +37,6 @@ class NapariMPLWidget(QWidget): ---------- viewer : `napari.Viewer` Main napari viewer. - figure : `matplotlib.figure.Figure` - Matplotlib figure. canvas : matplotlib.backends.backend_qt5agg.FigureCanvas Matplotlib canvas. layers : `list` @@ -64,6 +66,11 @@ def __init__(self, napari_viewer: napari.viewer.Viewer): # Accept any type of input layer by default input_layer_types: Tuple[napari.layers.Layer, ...] = (napari.layers.Layer,) + @property + def figure(self) -> Figure: + """Matplotlib figure.""" + return self.canvas.figure + @property def n_selected_layers(self) -> int: """ @@ -125,25 +132,31 @@ def draw(self) -> None: This is a no-op, and is intended for derived classes to override. """ - def apply_napari_colorscheme(self) -> None: - """Apply napari-compatible colorscheme to the axes object.""" - if self.axes is None: - return + def add_single_axes(self) -> None: + """ + Add a single Axes to the figure. + + The Axes is saved on the ``.axes`` attribute for later access. + """ + self.axes = self.figure.subplots() + self.apply_napari_colorscheme(self.axes) + + @staticmethod + def apply_napari_colorscheme(ax: Axes) -> None: + """Apply napari-compatible colorscheme to an axes object.""" # changing color of axes background to transparent - self.canvas.figure.patch.set_facecolor("none") - self.axes.set_facecolor("none") + ax.set_facecolor("none") # changing colors of all axes - [ - self.axes.spines[spine].set_color("white") - for spine in self.axes.spines - ] - self.axes.xaxis.label.set_color("white") - self.axes.yaxis.label.set_color("white") + for spine in ax.spines: + ax.spines[spine].set_color("white") + + ax.xaxis.label.set_color("white") + ax.yaxis.label.set_color("white") # changing colors of axes labels - self.axes.tick_params(axis="x", colors="white") - self.axes.tick_params(axis="y", colors="white") + ax.tick_params(axis="x", colors="white") + ax.tick_params(axis="y", colors="white") def _on_update_layers(self) -> None: """ diff --git a/src/napari_matplotlib/histogram.py b/src/napari_matplotlib/histogram.py index 0dab0bdf..7e863826 100644 --- a/src/napari_matplotlib/histogram.py +++ b/src/napari_matplotlib/histogram.py @@ -21,8 +21,7 @@ class HistogramWidget(NapariMPLWidget): def __init__(self, napari_viewer: napari.viewer.Viewer): super().__init__(napari_viewer) - self.axes = self.canvas.figure.subplots() - self.apply_napari_colorscheme() + self.add_single_axes() self.update_layers(None) def clear(self) -> None: diff --git a/src/napari_matplotlib/scatter.py b/src/napari_matplotlib/scatter.py index 71b762bb..6c0d3d90 100644 --- a/src/napari_matplotlib/scatter.py +++ b/src/napari_matplotlib/scatter.py @@ -31,8 +31,7 @@ class ScatterBaseWidget(NapariMPLWidget): def __init__(self, napari_viewer: napari.viewer.Viewer): super().__init__(napari_viewer) - self.axes = self.canvas.figure.subplots() - self.apply_napari_colorscheme() + self.add_single_axes() self.update_layers(None) def clear(self) -> None: diff --git a/src/napari_matplotlib/slice.py b/src/napari_matplotlib/slice.py index d6788706..bd8d219a 100644 --- a/src/napari_matplotlib/slice.py +++ b/src/napari_matplotlib/slice.py @@ -24,7 +24,7 @@ class SliceWidget(NapariMPLWidget): def __init__(self, napari_viewer: napari.viewer.Viewer): # Setup figure/axes super().__init__(napari_viewer) - self.axes = self.canvas.figure.subplots() + self.add_single_axes() button_layout = QHBoxLayout() self.layout().addLayout(button_layout)