Skip to content

Commit cbb42d1

Browse files
committed
Simplify scatter code
1 parent 58223ca commit cbb42d1

File tree

3 files changed

+66
-69
lines changed

3 files changed

+66
-69
lines changed

Diff for: CHANGELOG.rst

+10-10
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
0.0.2
1+
0.4.0
22
=====
33

4-
New features
5-
------------
6-
- `HistogramWidget` now shows individual histograms for RGB channels when
7-
present.
8-
9-
10-
Bug fixes
11-
---------
12-
- `HistogramWidget` now works properly with 2D images.
4+
Changes
5+
-------
6+
- The scatter widgets no longer use a LogNorm() for 2D histogram scaling.
7+
This is to move the widget in line with the philosophy of using Matplotlib default
8+
settings throughout ``napari-matplotlib``. This still leaves open the option of
9+
adding the option to change the normalization in the future. If this is something
10+
you would be interested in please open an issue at https://github.com/matplotlib/napari-matplotlib.
11+
- Labels plotting with the features scatter widget no longer have underscores
12+
replaced with spaces.

Diff for: src/napari_matplotlib/scatter.py

+48-52
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from typing import Any, List, Optional, Tuple
22

3-
import matplotlib.colors as mcolor
43
import napari
54
import numpy.typing as npt
65
from magicgui import magicgui
@@ -17,15 +16,8 @@ class ScatterBaseWidget(NapariMPLWidget):
1716
Base class for widgets that scatter two datasets against each other.
1817
"""
1918

20-
# opacity value for the markers
21-
_marker_alpha = 0.5
22-
23-
# flag set to True if histogram should be used
24-
# for plotting large points
25-
_histogram_for_large_data = True
26-
2719
# if the number of points is greater than this value,
28-
# the scatter is plotted as a 2dhist
20+
# the scatter is plotted as a 2D histogram
2921
_threshold_to_switch_to_histogram = 500
3022

3123
def __init__(self, napari_viewer: napari.viewer.Viewer):
@@ -44,40 +36,32 @@ def draw(self) -> None:
4436
"""
4537
Scatter the currently selected layers.
4638
"""
47-
data, x_axis_name, y_axis_name = self._get_data()
48-
49-
if len(data) == 0:
50-
# don't plot if there isn't data
51-
return
39+
x, y, x_axis_name, y_axis_name = self._get_data()
5240

53-
if self._histogram_for_large_data and (
54-
data[0].size > self._threshold_to_switch_to_histogram
55-
):
41+
if x.size > self._threshold_to_switch_to_histogram:
5642
self.axes.hist2d(
57-
data[0].ravel(),
58-
data[1].ravel(),
43+
x.ravel(),
44+
y.ravel(),
5945
bins=100,
60-
norm=mcolor.LogNorm(),
6146
)
6247
else:
63-
self.axes.scatter(data[0], data[1], alpha=self._marker_alpha)
48+
self.axes.scatter(x, y, alpha=0.5)
6449

6550
self.axes.set_xlabel(x_axis_name)
6651
self.axes.set_ylabel(y_axis_name)
6752

68-
def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]:
69-
"""Get the plot data.
53+
def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]:
54+
"""
55+
Get the plot data.
7056
7157
This must be implemented on the subclass.
7258
7359
Returns
7460
-------
75-
data : np.ndarray
76-
The list containing the scatter plot data.
77-
x_axis_name : str
78-
The label to display on the x axis
79-
y_axis_name: str
80-
The label to display on the y axis
61+
x, y : np.ndarray
62+
x and y values of plot data.
63+
x_axis_name, y_axis_name : str
64+
Label to display on the x/y axis
8165
"""
8266
raise NotImplementedError
8367

@@ -93,7 +77,7 @@ class ScatterWidget(ScatterBaseWidget):
9377
n_layers_input = Interval(2, 2)
9478
input_layer_types = (napari.layers.Image,)
9579

96-
def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]:
80+
def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]:
9781
"""
9882
Get the plot data.
9983
@@ -106,11 +90,12 @@ def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]:
10690
y_axis_name: str
10791
The title to display on the y axis
10892
"""
109-
data = [layer.data[self.current_z] for layer in self.layers]
93+
x = self.layers[0].data[self.current_z]
94+
y = self.layers[1].data[self.current_z]
11095
x_axis_name = self.layers[0].name
11196
y_axis_name = self.layers[1].name
11297

113-
return data, x_axis_name, y_axis_name
98+
return x, y, x_axis_name, y_axis_name
11499

115100

116101
class FeaturesScatterWidget(ScatterBaseWidget):
@@ -191,9 +176,33 @@ def _get_valid_axis_keys(
191176
else:
192177
return self.layers[0].features.keys()
193178

194-
def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]:
179+
def _ready_to_scatter(self) -> bool:
195180
"""
196-
Get the plot data.
181+
Return True if selected layer has a feature table we can scatter with,
182+
and the two columns to be scatterd have been selected.
183+
"""
184+
if not hasattr(self.layers[0], "features"):
185+
return False
186+
187+
feature_table = self.layers[0].features
188+
return (
189+
feature_table is not None
190+
and len(feature_table) > 0
191+
and self.x_axis_key is not None
192+
and self.y_axis_key is not None
193+
)
194+
195+
def draw(self) -> None:
196+
"""
197+
Scatter two features from the currently selected layer.
198+
"""
199+
if self._ready_to_scatter():
200+
super().draw()
201+
202+
def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]:
203+
"""
204+
Get the plot data from the ``features`` attribute of the first
205+
selected layer.
197206
198207
Returns
199208
-------
@@ -207,28 +216,15 @@ def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]:
207216
The title to display on the y axis. Returns
208217
an empty string if nothing to plot.
209218
"""
210-
if not hasattr(self.layers[0], "features"):
211-
# if the selected layer doesn't have a featuretable,
212-
# skip draw
213-
return [], "", ""
214-
215219
feature_table = self.layers[0].features
216220

217-
if (
218-
(len(feature_table) == 0)
219-
or (self.x_axis_key is None)
220-
or (self.y_axis_key is None)
221-
):
222-
return [], "", ""
223-
224-
data_x = feature_table[self.x_axis_key]
225-
data_y = feature_table[self.y_axis_key]
226-
data = [data_x, data_y]
221+
x = feature_table[self.x_axis_key]
222+
y = feature_table[self.y_axis_key]
227223

228-
x_axis_name = self.x_axis_key.replace("_", " ")
229-
y_axis_name = self.y_axis_key.replace("_", " ")
224+
x_axis_name = str(self.x_axis_key)
225+
y_axis_name = str(self.y_axis_key)
230226

231-
return data, x_axis_name, y_axis_name
227+
return x, y, x_axis_name, y_axis_name
232228

233229
def _on_update_layers(self) -> None:
234230
"""

Diff for: src/napari_matplotlib/tests/test_scatter.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ def make_labels_layer_with_features() -> (
3939

4040

4141
def test_features_scatter_get_data(make_napari_viewer):
42-
"""Test the get data method"""
42+
"""
43+
Test the get data method.
44+
"""
4345
# make the label image
4446
label_image, feature_table = make_labels_layer_with_features()
4547

@@ -55,17 +57,16 @@ def test_features_scatter_get_data(make_napari_viewer):
5557
y_column = "feature_2"
5658
scatter_widget.y_axis_key = y_column
5759

58-
data, x_axis_name, y_axis_name = scatter_widget._get_data()
59-
np.testing.assert_allclose(
60-
data, np.stack((feature_table[x_column], feature_table[y_column]))
61-
)
60+
x, y, x_axis_name, y_axis_name = scatter_widget._get_data()
61+
np.testing.assert_allclose(x, feature_table[x_column])
62+
np.testing.assert_allclose(y, np.stack(feature_table[y_column]))
6263
assert x_axis_name == x_column.replace("_", " ")
6364
assert y_axis_name == y_column.replace("_", " ")
6465

6566

6667
def test_get_valid_axis_keys(make_napari_viewer):
67-
"""Test the values returned from
68-
FeaturesScatterWidget._get_valid_keys() when there
68+
"""
69+
Test the values returned from _get_valid_keys() when there
6970
are valid keys.
7071
"""
7172
# make the label image

0 commit comments

Comments
 (0)