def download_and_unpack_gold_standard(
    save_dir: Union[str, pathlib.Path],
    overwrite_existing: bool = True,
    timeout: float = 60.0,
):
    """Download 'gold_standard_labelled.zip' from Hugging Face and unpack it into ``save_dir``.

    The archive is streamed to a temporary ``.part`` file first and only renamed to its final
    name once the download completed, so an interrupted download never leaves a truncated zip
    behind. After unpacking, the zip file is always removed.

    Args:
        save_dir (Union[str, Path]): The path to save the dataset files in. It is created
            (including parents) if it does not exist.
        overwrite_existing (bool): If False, nothing is downloaded when the zip file or the
            unpacked dataset folder is already present in ``save_dir``. Defaults to True.
        timeout (float): Seconds to wait for the server before giving up. Defaults to 60.

    Raises:
        requests.HTTPError: If the server answers with an error status code.
        zipfile.BadZipFile: If the downloaded file is not a valid zip archive.
    """
    url = "https://huggingface.co/datasets/JLrumberger/Pan-Multiplex-Gold-Standard/resolve/main/gold_standard_labelled.zip"
    save_dir = pathlib.Path(save_dir)
    zip_path = save_dir / "gold_standard_labelled.zip"
    part_path = zip_path.with_name(zip_path.name + ".part")
    # NOTE(review): assumes the archive unpacks into a folder named after the zip file
    # ("gold_standard_labelled") -- confirm against the archive layout. If the name differs,
    # this path never exists and the check below simply falls back to the zip-only behaviour.
    unpacked_dir = save_dir / zip_path.stem

    # Create the save directory if it doesn't exist
    save_dir.mkdir(parents=True, exist_ok=True)

    # The zip is deleted after a successful run, so the unpacked folder has to be checked as
    # well; otherwise `overwrite_existing=False` could never skip a second download.
    if not overwrite_existing:
        for existing in (zip_path, unpacked_dir):
            if existing.exists():
                print(f"{existing} already exists. Skipping download.")
                return

    # Download the zip file
    print(f"Downloading {url} to {zip_path}...")
    try:
        # the context manager releases the connection even if writing fails half-way
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(part_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        os.replace(part_path, zip_path)
    finally:
        # only present here if the download did not complete
        if part_path.exists():
            os.remove(part_path)

    print(f"Downloaded {zip_path}")

    # Unpack the zip file
    print(f"Unpacking {zip_path}...")
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(save_dir)

    print(f"Unpacked to {save_dir}")

    # The archive is no longer needed once its content is on disk
    os.remove(zip_path)
    print(f"Removed {zip_path}")
class InteractiveDataset(object):
    """Dataset for the InteractiveViewer class.

    Stores several dataset objects under their names, keeps track of one active dataset and
    forwards reads of channels, segmentations and groundtruth to it. No dataset is active
    until `set_dataset` has been called.

    Args:
        datasets (dict): dictionary with dataset names as keys and dataset objects as values
    """
    def __init__(self, datasets: dict):
        self.datasets = datasets
        self.dataset_names = list(datasets.keys())
        # stays None until set_dataset() selects one
        self.dataset = None

    def _active_dataset(self):
        """Return the active dataset or fail with a clear message if none was selected."""
        if self.dataset is None:
            raise RuntimeError(
                "No active dataset, call set_dataset() with one of "
                f"{self.dataset_names} first."
            )
        return self.dataset

    def set_dataset(self, dataset_name: str):
        """Set the active dataset

        Args:
            dataset_name (str): name of the dataset
        Returns:
            the dataset object stored under dataset_name
        Raises:
            KeyError: if no dataset with that name is stored; the active dataset is unchanged
        """
        try:
            selected = self.datasets[dataset_name]
        except KeyError:
            raise KeyError(
                f"Unknown dataset '{dataset_name}', available datasets: "
                f"{list(self.datasets)}"
            ) from None
        self.dataset = selected
        return self.dataset

    def get_channel(self, fov: str, channel: str):
        """Get a channel from a fov of the active dataset

        Args:
            fov (str): name of a fov
            channel (str): channel name
        Returns:
            np.array: channel image
        Raises:
            RuntimeError: if no dataset was selected with set_dataset
        """
        return self._active_dataset().get_channel(fov, channel)

    def get_segmentation(self, fov: str):
        """Get the instance mask for a fov of the active dataset

        Args:
            fov (str): name of a fov
        Returns:
            np.array: instance mask
        Raises:
            RuntimeError: if no dataset was selected with set_dataset
        """
        return self._active_dataset().get_segmentation(fov)

    def get_groundtruth(self, fov: str, channel: str):
        """Get the groundtruth for a fov / channel combination of the active dataset

        Args:
            fov (str): name of a fov
            channel (str): channel name
        Returns:
            np.array: groundtruth activity mask (0: negative, 1: positive, 2: ambiguous)
        Raises:
            RuntimeError: if no dataset was selected with set_dataset
        """
        return self._active_dataset().get_groundtruth(fov, channel)
panhandler import matplotlib.pyplot as plt @@ -289,7 +289,7 @@ class InteractiveImageDuo(widgets.Image): title_left (str): Title of left image. title_right (str): Title of right image. """ - def __init__(self, figsize=(10, 5), title_left='Multiplexed image', title_right='Prediction'): + def __init__(self, figsize=(10, 5), title_left='Multiplexed image', title_right='Groundtruth'): super().__init__() self.title_left = title_left self.title_right = title_right @@ -359,30 +359,37 @@ def update_right_image(self, image): Args: image (np.array): Image to display. """ - self.ax[1].imshow(image) + self.ax[1].imshow(image, vmin=0, vmax=255) self.ax[1].title.set_text(self.title_right) self.ax[1].set_xticks([]) self.ax[1].set_yticks([]) self.fig.canvas.draw_idle() -class NimbusInteractiveViewer(NimbusViewer): - """Interactive viewer for Nimbus application. +class NimbusInteractiveGTViewer(NimbusViewer): + """Interactive viewer for Nimbus application that shows input data and ground truth + side by side. Args: dataset (MultiplexDataset): dataset object output_dir (str): Path to directory containing output of Nimbus application. - segmentation_naming_convention (fn): Function that maps input path to segmentation path - img_width (str): Width of images in viewer. - suffix (str): Suffix of images in dataset. - max_resolution (tuple): Maximum resolution of images in viewer. + figsize (tuple): Size of figure. 
""" def __init__( - self, dataset: MultiplexDataset, output_dir: str, img_width='600px', suffix=".tiff", - max_resolution=(2048, 2048) + self, datasets: InteractiveDataset, output_dir, figsize=(20, 10) ): - super().__init__(dataset, output_dir, img_width, suffix, max_resolution) - self.image = InteractiveImageDuo() + super().__init__( + datasets.datasets[datasets.dataset_names[0]], output_dir + ) + self.image = InteractiveImageDuo(figsize=figsize) + self.dataset = datasets.datasets[datasets.dataset_names[0]] + self.datasets = datasets + self.dataset_select = widgets.Select( + options=datasets.dataset_names, + description='Dataset:', + disabled=False + ) + self.dataset_select.observe(self.select_dataset, names='value') def layout(self): """Creates layout for viewer.""" @@ -392,15 +399,28 @@ def layout(self): self.blue_select ]) layout = widgets.HBox([ - widgets.HBox([ + # widgets.HBox([ + self.dataset_select, self.fov_select, channel_selectors, self.overlay_checkbox, self.update_button - ]), + # ]), ]) display(layout) + def select_dataset(self, change): + """Selects dataset to display. + + Args: + change (dict): Change dictionary from ipywidgets. + """ + self.dataset = self.datasets.set_dataset(change['new']) + self.fov_names = natsorted(copy(self.dataset.fovs)) + self.fov_select.options = self.fov_names + self.select_fov(None) + + def update_img(self, image_fn, composite_image): """Updates image in viewer by saving it as png and loading it with the viewer widget. 
@@ -444,10 +464,6 @@ def update_composite(self): non_none = [p for p in path_dict.values() if p] if not non_none: return - composite_image = self.create_composite_image(path_dict) - composite_image, _ = self.overlay( - composite_image, add_overlay=True - ) in_composite_image = self.create_composite_from_dataset(in_path_dict) in_composite_image, seg_boundaries = self.overlay( @@ -459,6 +475,27 @@ def update_composite(self): in_composite_image = np.clip(in_composite_image*255, 0, 255).astype(np.uint8) if seg_boundaries is not None: in_composite_image[seg_boundaries] = [127, 127, 127] + + img = in_composite_image[...,0].astype(np.float32) * 0 + right_images = [] + for c, s in {'red': self.red_select.value, + 'green': self.green_select.value, + 'blue': self.blue_select.value}.items(): + if s: + composite_image = self.dataset.get_groundtruth( + self.fov_select.value, s + ) + else: + composite_image = img + composite_image = np.squeeze(composite_image).astype(np.float32) + right_images.append(composite_image) + right_images = np.stack(right_images, axis=-1) + right_images = np.clip(right_images, 0, 2) + right_images[right_images == 2] = 0.3 + right_images[seg_boundaries] = 0.0 + right_images *= 255.0 + right_images = right_images.astype(np.uint8) + # update image viewers self.update_img(self.image.update_left_image, in_composite_image) - self.update_img(self.image.update_right_image, composite_image) \ No newline at end of file + self.update_img(self.image.update_right_image, right_images) diff --git a/tests/test_viewer_widget.py b/tests/test_viewer_widget.py index 13613cc..f6bb570 100644 --- a/tests/test_viewer_widget.py +++ b/tests/test_viewer_widget.py @@ -1,7 +1,10 @@ +from nimbus_inference.viewer_widget import InteractiveImageDuo, NimbusInteractiveGTViewer from nimbus_inference.viewer_widget import NimbusViewer from nimbus_inference.nimbus import Nimbus, prep_naming_convention from nimbus_inference.utils import MultiplexDataset from tests.test_utils import 
def test_InteractiveImageDuo():
    """Both panels of InteractiveImageDuo hold an image of the shape they were updated with."""
    shape = (256, 256)
    image_duo = InteractiveImageDuo(
        figsize=(10, 5), title_left='Left Image', title_right='Right Image'
    )
    assert isinstance(image_duo, InteractiveImageDuo)

    # random single-channel dummy images, one per panel
    left_image = np.random.randint(0, 255, shape, dtype=np.uint8)
    right_image = np.random.randint(0, 255, shape, dtype=np.uint8)

    image_duo.update_left_image(left_image)
    image_duo.update_right_image(right_image)

    # each axis must now carry an image with the shape that was passed in
    for panel_idx in (0, 1):
        displayed = image_duo.ax[panel_idx].images[0].get_array()
        assert displayed.shape == shape