1
1
from typing import Any , List , Optional , Tuple
2
2
3
- import matplotlib .colors as mcolor
4
3
import napari
5
4
import numpy .typing as npt
6
5
from magicgui import magicgui
@@ -17,15 +16,8 @@ class ScatterBaseWidget(NapariMPLWidget):
17
16
Base class for widgets that scatter two datasets against each other.
18
17
"""
19
18
20
- # opacity value for the markers
21
- _marker_alpha = 0.5
22
-
23
- # flag set to True if histogram should be used
24
- # for plotting large points
25
- _histogram_for_large_data = True
26
-
27
19
# if the number of points is greater than this value,
28
- # the scatter is plotted as a 2dhist
20
+ # the scatter is plotted as a 2D histogram
29
21
_threshold_to_switch_to_histogram = 500
30
22
31
23
def __init__ (self , napari_viewer : napari .viewer .Viewer ):
@@ -44,40 +36,32 @@ def draw(self) -> None:
44
36
"""
45
37
Scatter the currently selected layers.
46
38
"""
47
- data , x_axis_name , y_axis_name = self ._get_data ()
48
-
49
- if len (data ) == 0 :
50
- # don't plot if there isn't data
51
- return
39
+ x , y , x_axis_name , y_axis_name = self ._get_data ()
52
40
53
- if self ._histogram_for_large_data and (
54
- data [0 ].size > self ._threshold_to_switch_to_histogram
55
- ):
41
+ if x .size > self ._threshold_to_switch_to_histogram :
56
42
self .axes .hist2d (
57
- data [ 0 ] .ravel (),
58
- data [ 1 ] .ravel (),
43
+ x .ravel (),
44
+ y .ravel (),
59
45
bins = 100 ,
60
- norm = mcolor .LogNorm (),
61
46
)
62
47
else :
63
- self .axes .scatter (data [ 0 ], data [ 1 ] , alpha = self . _marker_alpha )
48
+ self .axes .scatter (x , y , alpha = 0.5 )
64
49
65
50
self .axes .set_xlabel (x_axis_name )
66
51
self .axes .set_ylabel (y_axis_name )
67
52
68
- def _get_data (self ) -> Tuple [List [npt .NDArray [Any ]], str , str ]:
69
- """Get the plot data.
53
+ def _get_data (self ) -> Tuple [npt .NDArray [Any ], npt .NDArray [Any ], str , str ]:
54
+ """
55
+ Get the plot data.
70
56
71
57
This must be implemented on the subclass.
72
58
73
59
Returns
74
60
-------
75
- data : np.ndarray
76
- The list containing the scatter plot data.
77
- x_axis_name : str
78
- The label to display on the x axis
79
- y_axis_name: str
80
- The label to display on the y axis
61
+ x, y : np.ndarray
62
+ x and y values of plot data.
63
+ x_axis_name, y_axis_name : str
64
+ Label to display on the x/y axis
81
65
"""
82
66
raise NotImplementedError
83
67
@@ -93,7 +77,7 @@ class ScatterWidget(ScatterBaseWidget):
93
77
n_layers_input = Interval (2 , 2 )
94
78
input_layer_types = (napari .layers .Image ,)
95
79
96
- def _get_data (self ) -> Tuple [List [ npt .NDArray [Any ]], str , str ]:
80
+ def _get_data (self ) -> Tuple [npt .NDArray [Any ], npt . NDArray [ Any ], str , str ]:
97
81
"""
98
82
Get the plot data.
99
83
@@ -106,11 +90,12 @@ def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]:
106
90
y_axis_name: str
107
91
The title to display on the y axis
108
92
"""
109
- data = [layer .data [self .current_z ] for layer in self .layers ]
93
+ x = self .layers [0 ].data [self .current_z ]
94
+ y = self .layers [1 ].data [self .current_z ]
110
95
x_axis_name = self .layers [0 ].name
111
96
y_axis_name = self .layers [1 ].name
112
97
113
- return data , x_axis_name , y_axis_name
98
+ return x , y , x_axis_name , y_axis_name
114
99
115
100
116
101
class FeaturesScatterWidget (ScatterBaseWidget ):
@@ -191,9 +176,33 @@ def _get_valid_axis_keys(
191
176
else :
192
177
return self .layers [0 ].features .keys ()
193
178
194
- def _get_data (self ) -> Tuple [ List [ npt . NDArray [ Any ]], str , str ] :
179
+ def _ready_to_scatter (self ) -> bool :
195
180
"""
196
- Get the plot data.
181
+ Return True if selected layer has a feature table we can scatter with,
182
+ and the two columns to be scatterd have been selected.
183
+ """
184
+ if not hasattr (self .layers [0 ], "features" ):
185
+ return False
186
+
187
+ feature_table = self .layers [0 ].features
188
+ return (
189
+ feature_table is not None
190
+ and len (feature_table ) > 0
191
+ and self .x_axis_key is not None
192
+ and self .y_axis_key is not None
193
+ )
194
+
195
+ def draw (self ) -> None :
196
+ """
197
+ Scatter two features from the currently selected layer.
198
+ """
199
+ if self ._ready_to_scatter ():
200
+ super ().draw ()
201
+
202
+ def _get_data (self ) -> Tuple [npt .NDArray [Any ], npt .NDArray [Any ], str , str ]:
203
+ """
204
+ Get the plot data from the ``features`` attribute of the first
205
+ selected layer.
197
206
198
207
Returns
199
208
-------
@@ -207,28 +216,15 @@ def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]:
207
216
The title to display on the y axis. Returns
208
217
an empty string if nothing to plot.
209
218
"""
210
- if not hasattr (self .layers [0 ], "features" ):
211
- # if the selected layer doesn't have a featuretable,
212
- # skip draw
213
- return [], "" , ""
214
-
215
219
feature_table = self .layers [0 ].features
216
220
217
- if (
218
- (len (feature_table ) == 0 )
219
- or (self .x_axis_key is None )
220
- or (self .y_axis_key is None )
221
- ):
222
- return [], "" , ""
223
-
224
- data_x = feature_table [self .x_axis_key ]
225
- data_y = feature_table [self .y_axis_key ]
226
- data = [data_x , data_y ]
221
+ x = feature_table [self .x_axis_key ]
222
+ y = feature_table [self .y_axis_key ]
227
223
228
- x_axis_name = self .x_axis_key . replace ( "_" , " " )
229
- y_axis_name = self .y_axis_key . replace ( "_" , " " )
224
+ x_axis_name = str ( self .x_axis_key )
225
+ y_axis_name = str ( self .y_axis_key )
230
226
231
- return data , x_axis_name , y_axis_name
227
+ return x , y , x_axis_name , y_axis_name
232
228
233
229
def _on_update_layers (self ) -> None :
234
230
"""
0 commit comments