diff --git a/examples/features_scatter.py b/examples/features_scatter.py new file mode 100644 index 00000000..ac8580d7 --- /dev/null +++ b/examples/features_scatter.py @@ -0,0 +1,36 @@ +import napari +import numpy as np +from skimage.measure import regionprops_table + +# make a test label image +label_image = np.zeros((100, 100), dtype=np.uint16) + +label_image[10:20, 10:20] = 1 +label_image[50:70, 50:70] = 2 + +feature_table_1 = regionprops_table( + label_image, properties=("label", "area", "perimeter") +) +feature_table_1["index"] = feature_table_1["label"] + +# make the points data +n_points = 100 +points_data = 100 * np.random.random((100, 2)) +points_features = { + "feature_0": np.random.random((n_points,)), + "feature_1": np.random.random((n_points,)), + "feature_2": np.random.random((n_points,)), +} + +# create the viewer +viewer = napari.Viewer() +viewer.add_labels(label_image, features=feature_table_1) +viewer.add_points(points_data, features=points_features) + +# make the widget +viewer.window.add_plugin_dock_widget( + plugin_name="napari-matplotlib", widget_name="FeaturesScatter" +) + +if __name__ == "__main__": + napari.run() diff --git a/src/napari_matplotlib/base.py b/src/napari_matplotlib/base.py index 52293e4d..00f4292f 100644 --- a/src/napari_matplotlib/base.py +++ b/src/napari_matplotlib/base.py @@ -87,13 +87,14 @@ def setup_callbacks(self) -> None: # z-step changed in viewer self.viewer.dims.events.current_step.connect(self._draw) # Layer selection changed in viewer - self.viewer.layers.selection.events.active.connect(self.update_layers) + self.viewer.layers.selection.events.changed.connect(self.update_layers) def update_layers(self, event: napari.utils.events.Event) -> None: """ Update the layers attribute with currently selected layers and re-draw. """ self.layers = list(self.viewer.layers.selection) + self._on_update_layers() self._draw() def _draw(self) -> None: @@ -103,6 +104,7 @@ def _draw(self) -> None: """ self.clear() if self.n_selected_layers != self.n_layers_input: + self.canvas.draw() return self.draw() self.canvas.draw() @@ -120,6 +122,14 @@ def draw(self) -> None: This is a no-op, and is intended for derived classes to override. """ + + + def _on_update_layers(self) -> None: + """This function is called when self.layers is updated via self.update_layers() + + This is a no-op, and is intended for derived classes to override. + """ + def _replace_toolbar_icons(self): # Modify toolbar icons and some tooltips for action in self.toolbar.actions(): diff --git a/src/napari_matplotlib/napari.yaml b/src/napari_matplotlib/napari.yaml index cd585879..b736592b 100644 --- a/src/napari_matplotlib/napari.yaml +++ b/src/napari_matplotlib/napari.yaml @@ -8,7 +8,11 @@ contributions: - id: napari-matplotlib.scatter python_name: napari_matplotlib:ScatterWidget - title: Make a scatter plot + title: Make a scatter plot of image intensities + + - id: napari-matplotlib.features_scatter + python_name: napari_matplotlib:FeaturesScatterWidget + title: Make a scatter plot of layer features - id: napari-matplotlib.slice python_name: napari_matplotlib:SliceWidget @@ -21,5 +25,8 @@ contributions: - command: napari-matplotlib.scatter display_name: Scatter + - command: napari-matplotlib.features_scatter + display_name: FeaturesScatter + - command: napari-matplotlib.slice display_name: 1D slice diff --git a/src/napari_matplotlib/scatter.py b/src/napari_matplotlib/scatter.py index c3b12742..324e9126 100644 --- a/src/napari_matplotlib/scatter.py +++ b/src/napari_matplotlib/scatter.py @@ -1,39 +1,225 @@ +from typing import List, Tuple, Union + import matplotlib.colors as mcolor import napari +import numpy as np +from magicgui import magicgui from .base import NapariMPLWidget -__all__ = ["ScatterWidget"] +__all__ = ["ScatterWidget", "FeaturesScatterWidget"] -class ScatterWidget(NapariMPLWidget): - """ - Widget to display scatter plot of two similarly shaped layers. +class ScatterBaseWidget(NapariMPLWidget): + # opacity value for the markers + _marker_alpha = 0.5 - If there are more than 500 data points, a 2D histogram is displayed instead - of a scatter plot, to avoid too many scatter points. - """ + # flag set to True if histogram should be used + # for plotting large points + _histogram_for_large_data = True - n_layers_input = 2 + # if the number of points is greater than this value, + # the scatter is plotted as a 2dhist + _threshold_to_switch_to_histogram = 500 - def __init__(self, napari_viewer: napari.viewer.Viewer): + def __init__( + self, + napari_viewer: napari.viewer.Viewer, + ): super().__init__(napari_viewer) + self.axes = self.canvas.figure.subplots() self.update_layers(None) + def clear(self) -> None: + self.axes.clear() + def draw(self) -> None: """ Clear the axes and scatter the currently selected layers. """ - data = [layer.data[self.current_z] for layer in self.layers] - if data[0].size < 500: - self.axes.scatter(data[0], data[1], alpha=0.5) - else: + data, x_axis_name, y_axis_name = self._get_data() + + if len(data) == 0: + # don't plot if there isn't data + return + + if self._histogram_for_large_data and ( + data[0].size > self._threshold_to_switch_to_histogram + ): self.axes.hist2d( data[0].ravel(), data[1].ravel(), bins=100, norm=mcolor.LogNorm(), ) - self.axes.set_xlabel(self.layers[0].name) - self.axes.set_ylabel(self.layers[1].name) + else: + self.axes.scatter(data[0], data[1], alpha=self._marker_alpha) + + self.axes.set_xlabel(x_axis_name) + self.axes.set_ylabel(y_axis_name) + + def _get_data(self) -> Tuple[List[np.ndarray], str, str]: + """Get the plot data. + + This must be implemented on the subclass. + + Returns + ------- + data : np.ndarray + The list containing the scatter plot data. + x_axis_name : str + The label to display on the x axis + y_axis_name: str + The label to display on the y axis + """ + raise NotImplementedError + + +class ScatterWidget(ScatterBaseWidget): + """ + Widget to display scatter plot of two similarly shaped image layers. + + If there are more than 500 data points, a 2D histogram is displayed instead + of a scatter plot, to avoid too many scatter points. + """ + + n_layers_input = 2 + + def __init__( + self, + napari_viewer: napari.viewer.Viewer, + ): + super().__init__( + napari_viewer, + ) + + def _get_data(self) -> Tuple[List[np.ndarray], str, str]: + """Get the plot data. + + Returns + ------- + data : List[np.ndarray] + List contains the in view slice of X and Y axis images. + x_axis_name : str + The title to display on the x axis + y_axis_name: str + The title to display on the y axis + """ + data = [layer.data[self.current_z] for layer in self.layers] + x_axis_name = self.layers[0].name + y_axis_name = self.layers[1].name + + return data, x_axis_name, y_axis_name + + +class FeaturesScatterWidget(ScatterBaseWidget): + n_layers_input = 1 + + def __init__( + self, + napari_viewer: napari.viewer.Viewer, + key_selection_gui: bool = True, + ): + self._key_selection_widget = None + super().__init__( + napari_viewer, + ) + + if key_selection_gui is True: + self._key_selection_widget = magicgui( + self._set_axis_keys, + x_axis_key={"choices": self._get_valid_axis_keys}, + y_axis_key={"choices": self._get_valid_axis_keys}, + call_button="plot", + ) + self.layout().addWidget(self._key_selection_widget.native) + + @property + def x_axis_key(self) -> Union[None, str]: + """Key to access x axis data from the FeaturesTable""" + return self._x_axis_key + + @x_axis_key.setter + def x_axis_key(self, key: Union[None, str]): + self._x_axis_key = key + self._draw() + + @property + def y_axis_key(self) -> Union[None, str]: + """Key to access y axis data from the FeaturesTable""" + return self._y_axis_key + + @y_axis_key.setter + def y_axis_key(self, key: Union[None, str]): + self._y_axis_key = key + self._draw() + + def _set_axis_keys(self, x_axis_key: str, y_axis_key: str): + """Set both axis keys and then redraw the plot""" + self._x_axis_key = x_axis_key + self._y_axis_key = y_axis_key + self._draw() + + def _get_valid_axis_keys(self, combo_widget=None) -> List[str]: + """Get the valid axis keys from the layer FeatureTable. + + Returns + ------- + axis_keys : List[str] + The valid axis keys in the FeatureTable. If the table is empty + or there isn't a table, returns an empty list. + """ + if len(self.layers) == 0 or not (hasattr(self.layers[0], "features")): + return [] + else: + return self.layers[0].features.keys() + + def _get_data(self) -> Tuple[List[np.ndarray], str, str]: + """Get the plot data. + + Returns + ------- + data : List[np.ndarray] + List contains X and Y columns from the FeatureTable. Returns + an empty array if nothing to plot. + x_axis_name : str + The title to display on the x axis. Returns + an empty string if nothing to plot. + y_axis_name: str + The title to display on the y axis. Returns + an empty string if nothing to plot. + """ + if not hasattr(self.layers[0], "features"): + # if the selected layer doesn't have a featuretable, + # skip draw + return np.array([]), "", "" + + feature_table = self.layers[0].features + + if ( + (len(feature_table) == 0) + or (self.x_axis_key is None) + or (self.y_axis_key is None) + ): + return np.array([]), "", "" + + data_x = feature_table[self.x_axis_key] + data_y = feature_table[self.y_axis_key] + data = [data_x, data_y] + + x_axis_name = self.x_axis_key.replace("_", " ") + y_axis_name = self.y_axis_key.replace("_", " ") + + return data, x_axis_name, y_axis_name + + def _on_update_layers(self) -> None: + """This is called when the layer selection changes + by self.update_layers(). + """ + if self._key_selection_widget is not None: + self._key_selection_widget.reset_choices() + + # reset the axis keys + self._x_axis_key = None + self._y_axis_key = None diff --git a/src/napari_matplotlib/tests/test_scatter.py b/src/napari_matplotlib/tests/test_scatter.py index 75a6fda6..8103968e 100644 --- a/src/napari_matplotlib/tests/test_scatter.py +++ b/src/napari_matplotlib/tests/test_scatter.py @@ -1,11 +1,99 @@ import numpy as np -from napari_matplotlib import ScatterWidget +from napari_matplotlib import FeaturesScatterWidget, ScatterWidget def test_scatter(make_napari_viewer): - # Smoke test adding a histogram widget + # Smoke test adding a scatter widget viewer = make_napari_viewer() viewer.add_image(np.random.random((100, 100))) viewer.add_image(np.random.random((100, 100))) ScatterWidget(viewer) + + +def test_features_scatter_widget(make_napari_viewer): + # Smoke test adding a features scatter widget + viewer = make_napari_viewer() + viewer.add_image(np.random.random((100, 100))) + viewer.add_labels(np.random.randint(0, 5, (100, 100))) + FeaturesScatterWidget(viewer) + + +def make_labels_layer_with_features(): + label_image = np.zeros((100, 100), dtype=np.uint16) + for label_value, start_index in enumerate([10, 30, 50], start=1): + end_index = start_index + 10 + label_image[start_index:end_index, start_index:end_index] = label_value + feature_table = { + "index": [1, 2, 3], + "feature_0": np.random.random((3,)), + "feature_1": np.random.random((3,)), + "feature_2": np.random.random((3,)), + } + return label_image, feature_table + + +def test_features_scatter_get_data(make_napari_viewer): + """test the get data method""" + # make the label image + label_image, feature_table = make_labels_layer_with_features() + + viewer = make_napari_viewer() + labels_layer = viewer.add_labels(label_image, features=feature_table) + scatter_widget = FeaturesScatterWidget(viewer) + + # select the labels layer + viewer.layers.selection = [labels_layer] + + x_column = "feature_0" + scatter_widget.x_axis_key = x_column + y_column = "feature_2" + scatter_widget.y_axis_key = y_column + + data, x_axis_name, y_axis_name = scatter_widget._get_data() + np.testing.assert_allclose( + data, np.stack((feature_table[x_column], feature_table[y_column])) + ) + assert x_axis_name == x_column.replace("_", " ") + assert y_axis_name == y_column.replace("_", " ") + + +def test_get_valid_axis_keys(make_napari_viewer): + """test the values returned from + FeaturesScatterWidget._get_valid_keys() when there + are valid keys. + """ + # make the label image + label_image, feature_table = make_labels_layer_with_features() + + viewer = make_napari_viewer() + labels_layer = viewer.add_labels(label_image, features=feature_table) + scatter_widget = FeaturesScatterWidget(viewer) + + viewer.layers.selection = [labels_layer] + valid_keys = scatter_widget._get_valid_axis_keys() + assert set(valid_keys) == set(feature_table.keys()) + + +def test_get_valid_axis_keys_no_valid_keys(make_napari_viewer): + """test the values returned from + FeaturesScatterWidget._get_valid_keys() when there + are not valid keys. + """ + # make the label image + label_image, _ = make_labels_layer_with_features() + + viewer = make_napari_viewer() + labels_layer = viewer.add_labels(label_image) + image_layer = viewer.add_image(np.random.random((100, 100))) + scatter_widget = FeaturesScatterWidget(viewer) + + # no features in a label image + viewer.layers.selection = [labels_layer] + valid_keys = scatter_widget._get_valid_axis_keys() + assert set(valid_keys) == set() + + # image layer doesn't have features + viewer.layers.selection = [image_layer] + valid_keys = scatter_widget._get_valid_axis_keys() + assert set(valid_keys) == set()