forked from matplotlib/napari-matplotlib
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_scatter.py
120 lines (93 loc) · 3.76 KB
/
test_scatter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from copy import deepcopy
from typing import Any, Dict, Tuple
import numpy as np
import numpy.typing as npt
import pytest
from napari_matplotlib import FeaturesScatterWidget, ScatterWidget
@pytest.mark.mpl_image_compare
def test_scatter(make_napari_viewer, astronaut_data):
    """
    Image-comparison test for ``ScatterWidget`` with two selected image layers.

    The astronaut image and its negation are added as two layers and both are
    selected. A deep copy of the widget's figure is returned so that the
    ``mpl_image_compare`` marker can compare it against the stored baseline.
    """
    image = astronaut_data[0]
    image_kwargs = astronaut_data[1]

    viewer = make_napari_viewer()
    widget = ScatterWidget(viewer)
    fig = widget.figure

    viewer.add_image(image, **image_kwargs, name="astronaut")
    viewer.add_image(image * -1, **image_kwargs, name="astronaut_reversed")

    # Start from an empty selection, then select exactly the two images.
    viewer.layers.selection.clear()
    for layer in (viewer.layers[0], viewer.layers[1]):
        viewer.layers.selection.add(layer)

    return deepcopy(fig)


def test_features_scatter_widget(make_napari_viewer):
    """Smoke test: constructing a ``FeaturesScatterWidget`` must not raise."""
    viewer = make_napari_viewer()

    # Populate the viewer with a random image and a random labels layer
    # (the labels layer is added without a feature table).
    image_data = np.random.random((100, 100))
    viewer.add_image(image_data)
    labels_data = np.random.randint(0, 5, (100, 100))
    viewer.add_labels(labels_data)

    FeaturesScatterWidget(viewer)


def make_labels_layer_with_features() -> (
    Tuple[npt.NDArray[np.uint16], Dict[str, Any]]
):
    """
    Build a labels image together with a matching per-label feature table.

    The image is a 100x100 ``uint16`` array containing three non-overlapping
    10x10 squares on the diagonal with label values 1, 2 and 3; all other
    pixels are background (0).

    Returns
    -------
    label_image
        The labels image described above.
    feature_table
        Mapping with an ``"index"`` entry listing the label values, plus
        ``"feature_0"``, ``"feature_1"`` and ``"feature_2"``: 1-D arrays of
        random floats in [0, 1), one value per label.
    """
    square_size = 10
    square_starts = [10, 30, 50]

    label_image = np.zeros((100, 100), dtype=np.uint16)
    label_values = []
    for label_value, start_index in enumerate(square_starts, start=1):
        end_index = start_index + square_size
        label_image[start_index:end_index, start_index:end_index] = label_value
        label_values.append(label_value)

    # Derive the index from the labels actually painted so the feature rows
    # always line up with the squares, rather than keeping a second,
    # separately hard-coded list of label values in sync by hand.
    n_labels = len(label_values)
    feature_table: Dict[str, Any] = {"index": label_values}
    for feature_number in range(3):
        feature_table[f"feature_{feature_number}"] = np.random.random((n_labels,))
    return label_image, feature_table


def test_features_scatter_get_data(make_napari_viewer):
    """
    Test ``FeaturesScatterWidget._get_data``.

    With a labels layer carrying a feature table selected and the x/y axis
    keys set to two feature columns, ``_get_data`` should return the values of
    those two columns together with the column names as the axis names.
    """
    # make the label image and its feature table
    label_image, feature_table = make_labels_layer_with_features()
    viewer = make_napari_viewer()
    labels_layer = viewer.add_labels(label_image, features=feature_table)
    scatter_widget = FeaturesScatterWidget(viewer)

    # Make the labels layer the sole selected layer.
    viewer.layers.selection = [labels_layer]

    x_column = "feature_0"
    scatter_widget.x_axis_key = x_column
    y_column = "feature_2"
    scatter_widget.y_axis_key = y_column

    x, y, x_axis_name, y_axis_name = scatter_widget._get_data()

    # Compare both axes the same way: the feature columns are already 1-D
    # arrays, so wrapping one side in np.stack was a redundant no-op.
    np.testing.assert_allclose(x, feature_table[x_column])
    np.testing.assert_allclose(y, feature_table[y_column])
    assert x_axis_name == x_column
    assert y_axis_name == y_column


def test_get_valid_axis_keys(make_napari_viewer):
    """
    Test ``FeaturesScatterWidget._get_valid_axis_keys`` when valid keys exist.

    With a labels layer that carries a feature table selected, every column
    of that table should be reported as a valid axis key.
    """
    label_image, feature_table = make_labels_layer_with_features()
    viewer = make_napari_viewer()
    labels_layer = viewer.add_labels(label_image, features=feature_table)
    scatter_widget = FeaturesScatterWidget(viewer)
    viewer.layers.selection = [labels_layer]

    expected_keys = set(feature_table.keys())
    assert set(scatter_widget._get_valid_axis_keys()) == expected_keys


def test_get_valid_axis_keys_no_valid_keys(make_napari_viewer):
    """
    Test ``FeaturesScatterWidget._get_valid_axis_keys`` when no valid keys exist.

    Selecting either a labels layer added without a feature table or an image
    layer should yield no axis keys at all.
    """
    label_image, _ = make_labels_layer_with_features()
    viewer = make_napari_viewer()
    labels_layer = viewer.add_labels(label_image)  # no features passed
    image_layer = viewer.add_image(np.random.random((100, 100)))
    scatter_widget = FeaturesScatterWidget(viewer)

    # Check each feature-less layer on its own.
    for featureless_layer in (labels_layer, image_layer):
        viewer.layers.selection = [featureless_layer]
        assert set(scatter_widget._get_valid_axis_keys()) == set()