diff --git a/src/napari_matplotlib/tests/conftest.py b/src/napari_matplotlib/tests/conftest.py index 0788292d..b90f9ad1 100644 --- a/src/napari_matplotlib/tests/conftest.py +++ b/src/napari_matplotlib/tests/conftest.py @@ -26,6 +26,20 @@ def brain_data(): return data.brain(), {"rgb": False} +@pytest.fixture +def points_with_features_data(): + n_points = 100 + np.random.seed(10) + points_data = 100 * np.random.random((100, 2)) + points_features = { + "feature_0": np.random.random((n_points,)), + "feature_1": np.random.random((n_points,)), + "feature_2": np.random.random((n_points,)), + } + + return points_data, {"features": points_features} + + @pytest.fixture(autouse=True, scope="session") def set_strict_qt(): env_var = "NAPARI_STRICT_QT" diff --git a/src/napari_matplotlib/tests/scatter/test_scatter_features.py b/src/napari_matplotlib/tests/scatter/test_scatter_features.py index fca8a767..c211a064 100644 --- a/src/napari_matplotlib/tests/scatter/test_scatter_features.py +++ b/src/napari_matplotlib/tests/scatter/test_scatter_features.py @@ -9,29 +9,22 @@ @pytest.mark.mpl_image_compare -def test_features_scatter_widget_2D(make_napari_viewer): +def test_features_scatter_widget_2D( + make_napari_viewer, points_with_features_data +): viewer = make_napari_viewer() viewer.theme = "light" widget = FeaturesScatterWidget(viewer) - # make the points data - n_points = 100 - np.random.seed(10) - points_data = 100 * np.random.random((100, 2)) - points_features = { - "feature_0": np.random.random((n_points,)), - "feature_1": np.random.random((n_points,)), - "feature_2": np.random.random((n_points,)), - } - - viewer.add_points(points_data, features=points_features) + viewer.add_points( + points_with_features_data[0], **points_with_features_data[1] + ) + assert len(viewer.layers) == 1 # De-select existing selection viewer.layers.selection.clear() # Select points data and chosen features - viewer.layers.selection.add( - viewer.layers["points_data"] - ) # images need to be selected + viewer.layers.selection.add(viewer.layers[0]) # images need to be selected widget.x_axis_key = "feature_0" widget.y_axis_key = "feature_1" diff --git a/src/napari_matplotlib/tests/test_layer_changes.py b/src/napari_matplotlib/tests/test_layer_changes.py index 1cdf299f..87baebfd 100644 --- a/src/napari_matplotlib/tests/test_layer_changes.py +++ b/src/napari_matplotlib/tests/test_layer_changes.py @@ -6,7 +6,11 @@ import pytest from napari.viewer import Viewer -from napari_matplotlib import HistogramWidget, SliceWidget +from napari_matplotlib import ( + FeaturesScatterWidget, + HistogramWidget, + SliceWidget, +) from napari_matplotlib.base import NapariMPLWidget from napari_matplotlib.tests.helpers import ( assert_figures_equal, @@ -39,10 +43,48 @@ def assert_one_layer_plot_changes( by `widget_cls` also changes. """ widget = widget_cls(viewer) - viewer.add_image(data1[0], **data1[1]) viewer.add_image(data2[0], **data2[1]) + assert_plot_changes(viewer, widget) + +@pytest.mark.parametrize("widget_cls", [FeaturesScatterWidget]) +def test_change_features_layer( + make_napari_viewer, points_with_features_data, widget_cls +): + """ + Test all widgets that take one layer with features as input to make sure the + plot changes when the napari layer selection changes. + """ + viewer = make_napari_viewer() + assert_features_plot_changes(viewer, widget_cls, points_with_features_data) + + +def assert_features_plot_changes( + viewer: Viewer, + widget_cls: Type[NapariMPLWidget], + data: Tuple[npt.NDArray[np.generic], Dict[str, Any]], +) -> None: + """ + When the selected layer is changed, make sure the plot generated + by `widget_cls` also changes. + """ + widget = widget_cls(viewer) + viewer.add_points(data[0], **data[1]) + # Change the features data for the second layer + data[1]["features"] = { + name: data + 1 for name, data in data[1]["features"].items() + } + viewer.add_points(data[0], **data[1]) + assert_plot_changes(viewer, widget) + + +def assert_plot_changes(viewer: Viewer, widget: NapariMPLWidget) -> None: + """ + Assert that a widget plot changes when the layer selection + is changed. The passed viewer must already have two layers + loaded. + """ # Select first layer viewer.layers.selection.clear() viewer.layers.selection.add(viewer.layers[0])