diff --git a/benchmark/__main__.py b/benchmark/__main__.py
new file mode 100644
index 0000000..bf2a4dd
--- /dev/null
+++ b/benchmark/__main__.py
@@ -0,0 +1,37 @@
+from argparse import ArgumentParser
+from os.path import exists
+
+from benchmark.data import DataTest, SlidingWindowTest, RandomROIDatasetTest
+from benchmark.training import TrainingTest, ResizeTrainingTest, SlidingTrainingTest
+from mipcandy import auto_device, download_dataset, Frontend, NotionFrontend, WandBFrontend
+
+BENCHMARK_DATASET: str = "AbdomenCT-1K-ss1"
+
+if __name__ == "__main__":
+    tests = {
+        "SlidingWindow": SlidingWindowTest,
+        "RandomROI": RandomROIDatasetTest,
+        "Training": TrainingTest,
+        "ResizeTraining": ResizeTrainingTest,
+        "SlidingTraining": SlidingTrainingTest
+    }
+    parser = ArgumentParser(prog="MIP Candy Benchmark", description="MIP Candy Benchmark",
+                            epilog="GitHub: https://github.com/ProjectNeura/MIPCandy")
+    parser.add_argument("test", choices=tests.keys())
+    parser.add_argument("-i", "--input-folder")
+    parser.add_argument("-o", "--output-folder")
+    parser.add_argument("--num-epochs", type=int, default=100)
+    parser.add_argument("--device", default=None)
+    parser.add_argument("--front-end", choices=(None, "n", "w"), default=None)
+    args = parser.parse_args()
+    DataTest.dataset = BENCHMARK_DATASET
+    test = tests[args.test](
+        args.input_folder, args.output_folder, args.num_epochs, args.device if args.device else auto_device(), {
+            None: Frontend, "n": NotionFrontend, "w": WandBFrontend
+        }[args.front_end]
+    )
+    if not exists(f"{args.input_folder}/{BENCHMARK_DATASET}"):
+        download_dataset(f"nnunet_datasets/{BENCHMARK_DATASET}", f"{args.input_folder}/{BENCHMARK_DATASET}")
+    stat, err = test.run()
+    if not stat:
+        raise err
diff --git a/benchmark/data.py b/benchmark/data.py
new file mode 100644
index 0000000..b09dab1
--- /dev/null
+++ b/benchmark/data.py
@@ -0,0 +1,69 @@
+from os import makedirs
+from time import time
+from typing import override, Literal
+
+from rich.progress import Progress
+
+from benchmark.prototype import UnitTest
+from mipcandy import NNUNetDataset, do_sliding_window, visualize3d, revert_sliding_window, JointTransform, inspect, \
+    RandomROIDataset
+
+
+class DataTest(UnitTest):
+    dataset: str = "AbdomenCT-1K-ss1"
+    transform: JointTransform | None = None
+
+    @override
+    def set_up(self) -> None:
+        self["dataset"] = NNUNetDataset(f"{self.input_folder}/{DataTest.dataset}", transform=self.transform,
+                                        device=self.device)
+        self["dataset"].preload(f"{self.input_folder}/{DataTest.dataset}/preloaded")
+
+
+class FoldedDataTest(DataTest):
+    fold: Literal[0, 1, 2, 3, 4, "all"] = 0
+
+    @override
+    def set_up(self) -> None:
+        super().set_up()
+        self["train_dataset"], self["val_dataset"] = self["dataset"].fold(fold=self.fold)
+
+
+class SlidingWindowTest(DataTest):
+    @override
+    def execute(self) -> None:
+        image, _ = self["dataset"][0]
+        print(image.shape)
+        visualize3d(image, title="raw")
+        t0 = time()
+        windows, layout, pad = do_sliding_window(image, (128, 128, 128))
+        print(f"took {time() - t0:.2f}s")
+        print(windows[0].shape, layout)
+        t0 = time()
+        recon = revert_sliding_window(windows, layout, pad)
+        print(f"took {time() - t0:.2f}s")
+        print(recon.shape)
+        visualize3d(recon, title="reconstructed")
+
+
+class RandomROIDatasetTest(DataTest):
+    @override
+    def execute(self) -> None:
+        annotations = inspect(self["dataset"])
+        dataset = RandomROIDataset(annotations, 2)
+        print(dataset.roi_shape())
+        o = f"{self.output_folder}/RandomROIPreviews"
+        makedirs(o, exist_ok=True)
+        makedirs(f"{o}/images", exist_ok=True)
+        makedirs(f"{o}/labels", exist_ok=True)
+        makedirs(f"{o}/imageROIs", exist_ok=True)
+        makedirs(f"{o}/labelROIs", exist_ok=True)
+        with Progress() as progress:
+            task = progress.add_task("Generating Previews...", total=len(dataset))
+            for idx, (image_roi, label_roi) in enumerate(dataset):
+                image, label = self["dataset"][idx]
+                visualize3d(image, title="image raw", screenshot_as=f"{o}/images/{idx}.png")
+                visualize3d(label.int(), title="label raw", is_label=True, screenshot_as=f"{o}/labels/{idx}.png")
+                visualize3d(image_roi, title="image roi", screenshot_as=f"{o}/imageROIs/{idx}.png")
+                visualize3d(label_roi.int(), title="label roi", is_label=True, screenshot_as=f"{o}/labelROIs/{idx}.png")
+                progress.update(task, advance=1)
diff --git a/benchmark/prototype.py b/benchmark/prototype.py
new file mode 100644
index 0000000..1708421
--- /dev/null
+++ b/benchmark/prototype.py
@@ -0,0 +1,41 @@
+from os import PathLike
+from typing import Any
+
+from mipcandy import Device, Frontend
+
+
+class UnitTest(object):
+    def __init__(self, input_folder: str | PathLike[str], output_folder: str | PathLike[str], num_epochs: int,
+                 device: Device, frontend: type[Frontend]) -> None:
+        self.input_folder: str = input_folder
+        self.output_folder: str = output_folder
+        self.num_epochs: int = num_epochs
+        self.device: Device = device
+        self.frontend: type[Frontend] = frontend
+
+    def set_up(self) -> None:
+        pass
+
+    def execute(self) -> None:
+        pass
+
+    def clean_up(self) -> None:
+        pass
+
+    def run(self) -> tuple[bool, Exception | None]:
+        try:
+            self.set_up()
+            self.execute()
+        except Exception as e:
+            try:
+                self.clean_up()
+            except Exception as e2:
+                print(f"Failed to clean up after exception: {e2}")
+            return False, e
+        return True, None
+
+    def __setitem__(self, key: str, value: Any) -> None:
+        setattr(self, "_x_" + key, value)
+
+    def __getitem__(self, item: str) -> Any:
+        return getattr(self, "_x_" + item)
diff --git a/benchmark/training.py b/benchmark/training.py
new file mode 100644
index 0000000..70b6064
--- /dev/null
+++ b/benchmark/training.py
@@ -0,0 +1,118 @@
+from os import removedirs
+from os.path import exists
+from typing import override
+
+from monai.transforms import Resized
+from torch.utils.data import DataLoader
+
+from benchmark.data import DataTest, FoldedDataTest
+from benchmark.unet import UNetTrainer, UNetSlidingTrainer
+from mipcandy import SegmentationTrainer, slide_dataset, Shape, SupervisedSWDataset, JointTransform, inspect, \
+    load_inspection_annotations, RandomROIDataset
+
+
+class TrainingTest(DataTest):
+    trainer: type[SegmentationTrainer] = UNetTrainer
+    resize: Shape = (128, 128, 128)
+    num_classes: int = 5
+    _continue: str | None = None  # internal flag for continued training
+
+    def set_up_datasets(self) -> None:
+        super().set_up()
+        path = f"{self.output_folder}/training_test.json"
+        self["dataset"].device(device="cpu")
+        if exists(path):
+            annotations = load_inspection_annotations(path, self["dataset"])
+        else:
+            annotations = inspect(self["dataset"])
+            annotations.save(path)
+        dataset = RandomROIDataset(annotations, 2, num_patches_per_case=2)
+        dataset.roi_shape(roi_shape=(128, 128, 128))
+        self["train_dataset"], self["val_dataset"] = dataset.fold(fold=0)
+
+    @override
+    def set_up(self) -> None:
+        self.set_up_datasets()
+        train, val = self["train_dataset"], self["val_dataset"]
+        val.preload(f"{self.output_folder}/valPreloaded")
+        # train.set_transform(JointTransform(image_only=Normalize(domain=(0, 1), strict=True)))
+        # val.set_transform(JointTransform(image_only=Normalize(domain=(0, 1), strict=True)))
+        train_dataloader = DataLoader(train, batch_size=2, shuffle=True, pin_memory=True, prefetch_factor=2,
+                                      num_workers=2, persistent_workers=True)
+        val_dataloader = DataLoader(val, batch_size=1, shuffle=False, pin_memory=True)
+        trainer = self.trainer(self.output_folder, train_dataloader, val_dataloader, device=self.device)
+        trainer.num_classes = self.num_classes
+        trainer.set_frontend(self.frontend)
+        self["trainer"] = trainer
+
+    @override
+    def execute(self) -> None:
+        if not self._continue:
+            return self["trainer"].train(self.num_epochs, note=f"Training test {self.resize}", compile_model=False,
+                                         val_score_prediction=False)
+        self["trainer"].recover_from(self._continue)
+        return self["trainer"].continue_training(self.num_epochs)
+
+    @override
+    def clean_up(self) -> None:
+        removedirs(self["trainer"].experiment_folder())
+
+
+class ResizeTrainingTest(FoldedDataTest):
+    trainer: type[SegmentationTrainer] = UNetTrainer
+    resize: Shape = (256, 256, 256)
+    num_classes: int = 5
+
+    @override
+    def set_up(self) -> None:
+        self.transform = JointTransform(transform=Resized(("image", "label"), self.resize))
+        super().set_up()
+        train_dataloader = DataLoader(self["train_dataset"], batch_size=2, shuffle=True)
+        val_dataloader = DataLoader(self["val_dataset"], batch_size=1, shuffle=False)
+        trainer = self.trainer(self.output_folder, train_dataloader, val_dataloader, recoverable=False,
+                               profiler=True, device=self.device)
+        trainer.num_classes = self.num_classes
+        trainer.set_frontend(self.frontend)
+        self["trainer"] = trainer
+
+    @override
+    def execute(self) -> None:
+        self["trainer"].train(self.num_epochs, note=f"Resize Training test {self.resize}")
+
+    @override
+    def clean_up(self) -> None:
+        removedirs(self["trainer"].experiment_folder())
+
+
+class SlidingTrainingTest(TrainingTest, FoldedDataTest):
+    trainer: type[SegmentationTrainer] = UNetSlidingTrainer
+    window_shape: Shape = (128, 128, 128)
+    overlap: float = .5
+
+    @override
+    def set_up(self) -> None:
+        self.set_up_datasets()
+        train, val = self["train_dataset"], self["val_dataset"]
+        FoldedDataTest.set_up(self)
+        full_val = self["val_dataset"]
+        path = f"{self.output_folder}/val_slided"
+        if not exists(path):
+            slide_dataset(full_val, path, self.window_shape, overlap=self.overlap)
+        slided_val = SupervisedSWDataset(path)
+        train_dataloader = DataLoader(train, batch_size=2, shuffle=True)
+        val_dataloader = DataLoader(val, batch_size=1, shuffle=False)
+        trainer = self.trainer(self.output_folder, train_dataloader, val_dataloader, recoverable=False,
+                               profiler=True, device=self.device)
+        trainer.set_datasets(full_val, slided_val)
+        trainer.num_classes = self.num_classes
+        trainer.overlap = self.overlap
+        trainer.set_frontend(self.frontend)
+        self["trainer"] = trainer
+
+    @override
+    def execute(self) -> None:
+        self["trainer"].train(self.num_epochs, note="Training test with sliding window")
+
+    @override
+    def clean_up(self) -> None:
+        removedirs(self["trainer"].experiment_folder())
diff --git a/benchmark/transforms.py b/benchmark/transforms.py
new file mode 100644
index 0000000..9698436
--- /dev/null
+++ b/benchmark/transforms.py
@@ -0,0 +1,642 @@
+"""
+MIPCandy Transform Module - nnUNet-compatible data augmentation using MONAI.
+
+This module provides nnUNet-style transforms built on top of MONAI's transform infrastructure.
+Only implements transforms that MONAI doesn't provide natively.
+""" +from __future__ import annotations + +from typing import Hashable, Sequence + +import numpy as np +import torch +from monai.config import KeysCollection +from monai.transforms import ( + Compose, + MapTransform, + OneOf, + RandAdjustContrastd, + RandAffined, + RandFlipd, + RandGaussianNoised, + RandGaussianSmoothd, + RandScaleIntensityd, + RandSimulateLowResolutiond, + Randomizable, + Transform, +) +from scipy.ndimage import label as scipy_label +from skimage.morphology import ball, disk +from torch.nn.functional import conv2d, conv3d, interpolate, pad + + +# ============================================================================= +# nnUNet-specific Scalar Sampling +# ============================================================================= +class BGContrast: + """nnUNet-style contrast/gamma sampling - biased towards values around 1.""" + + def __init__(self, value_range: tuple[float, float]) -> None: + self._range: tuple[float, float] = value_range + + def __call__(self) -> float: + if np.random.random() < 0.5 and self._range[0] < 1: + return float(np.random.uniform(self._range[0], 1)) + return float(np.random.uniform(max(self._range[0], 1), self._range[1])) + + +# ============================================================================= +# Transforms MONAI doesn't have (nnUNet-specific) +# ============================================================================= +class DownsampleSegForDS(Transform): + """Downsample segmentation for deep supervision - produces list of tensors.""" + + def __init__(self, scales: Sequence[float | Sequence[float]]) -> None: + self._scales: list = list(scales) + + def __call__(self, seg: torch.Tensor) -> list[torch.Tensor]: + results = [] + for s in self._scales: + if not isinstance(s, (tuple, list)): + s = [s] * (seg.ndim - 1) + if all(i == 1 for i in s): + results.append(seg) + else: + new_shape = [round(dim * scale) for dim, scale in zip(seg.shape[1:], s)] + results.append(interpolate(seg[None].float(), new_shape, mode="nearest-exact")[0].to(seg.dtype)) + return results + + +class DownsampleSegForDSd(MapTransform): + """Dictionary version of DownsampleSegForDS.""" + + def __init__(self, keys: KeysCollection, scales: Sequence[float | Sequence[float]]) -> None: + super().__init__(keys, allow_missing_keys=False) + self._transform = DownsampleSegForDS(scales) + + def __call__(self, data: dict[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.keys: + d[key] = self._transform(d[key]) + return d + + +class Convert3DTo2D(Transform): + """Convert 3D data to 2D by merging first spatial dim into channels (for anisotropic data).""" + + def __call__(self, img: torch.Tensor) -> tuple[torch.Tensor, int]: + nch = img.shape[0] + return img.reshape(img.shape[0] * img.shape[1], *img.shape[2:]), nch + + +class Convert2DTo3D(Transform): + """Convert 2D data back to 3D.""" + + def __call__(self, img: torch.Tensor, nch: int) -> torch.Tensor: + return img.reshape(nch, img.shape[0] // nch, *img.shape[1:]) + + +class Convert3DTo2Dd(MapTransform): + """Dictionary version - stores channel counts for restoration.""" + + def __init__(self, keys: KeysCollection) -> None: + super().__init__(keys, allow_missing_keys=False) + self._transform = Convert3DTo2D() + + def __call__(self, data: dict[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.keys: + d[key], d[f"_nch_{key}"] = self._transform(d[key]) + return d + + +class Convert2DTo3Dd(MapTransform): + """Dictionary version - 
restores from stored channel counts.""" + + def __init__(self, keys: KeysCollection) -> None: + super().__init__(keys, allow_missing_keys=False) + self._transform = Convert2DTo3D() + + def __call__(self, data: dict[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.keys: + nch_key = f"_nch_{key}" + d[key] = self._transform(d[key], d[nch_key]) + del d[nch_key] + return d + + +class ConvertSegToRegions(Transform): + """Convert segmentation to region-based binary masks.""" + + def __init__(self, regions: Sequence[int | Sequence[int]], channel: int = 0) -> None: + self._regions: list[torch.Tensor] = [ + torch.tensor([r]) if isinstance(r, int) else torch.tensor(r) for r in regions + ] + self._channel: int = channel + + def __call__(self, seg: torch.Tensor) -> torch.Tensor: + output = torch.zeros((len(self._regions), *seg.shape[1:]), dtype=torch.bool, device=seg.device) + for i, labels in enumerate(self._regions): + if len(labels) == 1: + output[i] = seg[self._channel] == labels[0] + else: + output[i] = torch.isin(seg[self._channel], labels) + return output + + +class ConvertSegToRegionsd(MapTransform): + """Dictionary version of ConvertSegToRegions.""" + + def __init__(self, keys: KeysCollection, regions: Sequence[int | Sequence[int]], channel: int = 0) -> None: + super().__init__(keys, allow_missing_keys=False) + self._transform = ConvertSegToRegions(regions, channel) + + def __call__(self, data: dict[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.keys: + d[key] = self._transform(d[key]) + return d + + +class MoveSegAsOneHotToData(Transform): + """Move segmentation channel as one-hot encoding to image (for cascade training).""" + + def __init__(self, source_channel: int, labels: Sequence[int], remove_from_seg: bool = True) -> None: + self._source_channel: int = source_channel + self._labels: list[int] = list(labels) + self._remove: bool = remove_from_seg + + def __call__(self, image: torch.Tensor, seg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + seg_slice = seg[self._source_channel] + onehot = torch.zeros((len(self._labels), *seg_slice.shape), dtype=image.dtype) + for i, label in enumerate(self._labels): + onehot[i][seg_slice == label] = 1 + new_image = torch.cat((image, onehot)) + if self._remove: + keep = [i for i in range(seg.shape[0]) if i != self._source_channel] + seg = seg[keep] + return new_image, seg + + +class MoveSegAsOneHotToDatad(MapTransform): + """Dictionary version of MoveSegAsOneHotToData.""" + + def __init__( + self, + image_key: str, + seg_key: str, + source_channel: int, + labels: Sequence[int], + remove_from_seg: bool = True, + ) -> None: + super().__init__([image_key, seg_key], allow_missing_keys=False) + self._image_key: str = image_key + self._seg_key: str = seg_key + self._transform = MoveSegAsOneHotToData(source_channel, labels, remove_from_seg) + + def __call__(self, data: dict[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + d = dict(data) + d[self._image_key], d[self._seg_key] = self._transform(d[self._image_key], d[self._seg_key]) + return d + + +class RemoveLabel(Transform): + """Replace one label value with another in segmentation.""" + + def __init__(self, label: int, set_to: int) -> None: + self._label: int = label + self._set_to: int = set_to + + def __call__(self, seg: torch.Tensor) -> torch.Tensor: + seg = seg.clone() + seg[seg == self._label] = self._set_to + return seg + + +class RemoveLabeld(MapTransform): + """Dictionary version of 
RemoveLabel.""" + + def __init__(self, keys: KeysCollection, label: int, set_to: int) -> None: + super().__init__(keys, allow_missing_keys=False) + self._transform = RemoveLabel(label, set_to) + + def __call__(self, data: dict[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.keys: + d[key] = self._transform(d[key]) + return d + + +class RandApplyRandomBinaryOperator(Randomizable, Transform): + """Randomly apply binary morphological operations to one-hot channels.""" + + def __init__( + self, + channels: Sequence[int], + prob: float = 0.4, + strel_size: tuple[int, int] = (1, 8), + p_per_label: float = 1.0, + ) -> None: + self._channels: list[int] = list(channels) + self._prob: float = prob + self._strel_size: tuple[int, int] = strel_size + self._p_per_label: float = p_per_label + + def __call__(self, img: torch.Tensor) -> torch.Tensor: + if self.R.random() > self._prob: + return img + + channels = self._channels.copy() + self.R.shuffle(channels) + + for c in channels: + if self.R.random() > self._p_per_label: + continue + + size = self.R.randint(self._strel_size[0], self._strel_size[1] + 1) + op = self.R.choice([_binary_dilation, _binary_erosion, _binary_opening, _binary_closing]) + + workon = img[c].to(bool) + strel = torch.from_numpy(disk(size, dtype=bool) if workon.ndim == 2 else ball(size, dtype=bool)) + result = op(workon, strel) + + added = result & (~workon) + for oc in self._channels: + if oc != c: + img[oc][added] = 0 + img[c] = result.to(img.dtype) + + return img + + +class RandApplyRandomBinaryOperatord(MapTransform, Randomizable): + """Dictionary version of RandApplyRandomBinaryOperator.""" + + def __init__( + self, + keys: KeysCollection, + channels: Sequence[int], + prob: float = 0.4, + strel_size: tuple[int, int] = (1, 8), + p_per_label: float = 1.0, + ) -> None: + MapTransform.__init__(self, keys, allow_missing_keys=False) + self._transform = RandApplyRandomBinaryOperator(channels, prob, strel_size, p_per_label) + + def __call__(self, data: dict[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.keys: + d[key] = self._transform(d[key]) + return d + + +class RandRemoveConnectedComponent(Randomizable, Transform): + """Randomly remove connected components from one-hot encoding.""" + + def __init__( + self, + channels: Sequence[int], + prob: float = 0.2, + fill_with_other_p: float = 0.0, + max_coverage: float = 0.15, + p_per_label: float = 1.0, + ) -> None: + self._channels: list[int] = list(channels) + self._prob: float = prob + self._fill_p: float = fill_with_other_p + self._max_coverage: float = max_coverage + self._p_per_label: float = p_per_label + + def __call__(self, img: torch.Tensor) -> torch.Tensor: + if self.R.random() > self._prob: + return img + + channels = self._channels.copy() + self.R.shuffle(channels) + + for c in channels: + if self.R.random() > self._p_per_label: + continue + + workon = img[c].to(bool).numpy() + if not np.any(workon): + continue + + num_voxels = int(np.prod(workon.shape)) + labeled, num_components = scipy_label(workon) + if num_components == 0: + continue + + component_sizes = {i: int((labeled == i).sum()) for i in range(1, num_components + 1)} + valid = [i for i, size in component_sizes.items() if size < num_voxels * self._max_coverage] + + if valid: + chosen = self.R.choice(valid) + mask = labeled == chosen + img[c][mask] = 0 + + if self.R.random() < self._fill_p: + others = [i for i in self._channels if i != c] + if others: + other = self.R.choice(others) 
+ img[other][mask] = 1 + + return img + + +class RandRemoveConnectedComponentd(MapTransform, Randomizable): + """Dictionary version of RandRemoveConnectedComponent.""" + + def __init__( + self, + keys: KeysCollection, + channels: Sequence[int], + prob: float = 0.2, + fill_with_other_p: float = 0.0, + max_coverage: float = 0.15, + p_per_label: float = 1.0, + ) -> None: + MapTransform.__init__(self, keys, allow_missing_keys=False) + self._transform = RandRemoveConnectedComponent(channels, prob, fill_with_other_p, max_coverage, p_per_label) + + def __call__(self, data: dict[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.keys: + d[key] = self._transform(d[key]) + return d + + +class RandGammad(MapTransform, Randomizable): + """nnUNet-style gamma transform with invert option and retain_stats.""" + + def __init__( + self, + keys: KeysCollection, + prob: float = 0.3, + gamma: tuple[float, float] = (0.7, 1.5), + p_invert: float = 0.0, + p_per_channel: float = 1.0, + p_retain_stats: float = 1.0, + ) -> None: + super().__init__(keys, allow_missing_keys=False) + self._prob: float = prob + self._gamma: tuple[float, float] = gamma + self._p_invert: float = p_invert + self._p_per_channel: float = p_per_channel + self._p_retain_stats: float = p_retain_stats + + def __call__(self, data: dict[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + if self.R.random() > self._prob: + return data + + d = dict(data) + for key in self.keys: + img = d[key] + for c in range(img.shape[0]): + if self.R.random() > self._p_per_channel: + continue + + g = BGContrast(self._gamma)() + invert = self.R.random() < self._p_invert + retain = self.R.random() < self._p_retain_stats + + if invert: + img[c] *= -1 + if retain: + mean, std = img[c].mean(), img[c].std() + + minm = img[c].min() + rnge = (img[c].max() - minm).clamp(min=1e-7) + img[c] = torch.pow((img[c] - minm) / rnge, g) * rnge + minm + + if retain: + mn_here, std_here = img[c].mean(), img[c].std().clamp(min=1e-7) + img[c] = (img[c] - mn_here) * (std / std_here) + mean + if invert: + img[c] *= -1 + + d[key] = img + return d + + +# ============================================================================= +# Binary Morphology Helpers +# ============================================================================= +def _binary_dilation(tensor: torch.Tensor, strel: torch.Tensor) -> torch.Tensor: + tensor_f = tensor.float() + if tensor.ndim == 2: + strel_k = strel[None, None].float() + padded = pad(tensor_f[None, None], [strel.shape[-1] // 2] * 4, mode="constant", value=0) + out = conv2d(padded, strel_k) + else: + strel_k = strel[None, None].float() + padded = pad(tensor_f[None, None], [strel.shape[-1] // 2] * 6, mode="constant", value=0) + out = conv3d(padded, strel_k) + return (out > 0).squeeze(0).squeeze(0) + + +def _binary_erosion(tensor: torch.Tensor, strel: torch.Tensor) -> torch.Tensor: + return ~_binary_dilation(~tensor, strel) + + +def _binary_opening(tensor: torch.Tensor, strel: torch.Tensor) -> torch.Tensor: + return _binary_dilation(_binary_erosion(tensor, strel), strel) + + +def _binary_closing(tensor: torch.Tensor, strel: torch.Tensor) -> torch.Tensor: + return _binary_erosion(_binary_dilation(tensor, strel), strel) + + +# ============================================================================= +# Factory Functions - nnUNet-style Pipelines using MONAI +# ============================================================================= +def training_transforms( + keys: tuple[str, str] = ("image", 
"label"), + patch_size: tuple[int, ...] = (128, 128, 128), + rotation: tuple[float, float] = (-30 / 360 * 2 * np.pi, 30 / 360 * 2 * np.pi), + scale: tuple[float, float] = (0.7, 1.4), + mirror_axes: tuple[int, ...] | None = (0, 1, 2), + do_dummy_2d: bool = False, + deep_supervision_scales: Sequence[float] | None = None, + is_cascaded: bool = False, + foreground_labels: Sequence[int] | None = None, + regions: Sequence[int | Sequence[int]] | None = None, + ignore_label: int | None = None, +) -> Compose: + """ + Create nnUNet-style training transforms using MONAI infrastructure. + + Args: + keys: (image_key, label_key) for dictionary transforms + patch_size: spatial size of output patches + rotation: (min, max) rotation in radians + scale: (min, max) scale factors + mirror_axes: axes to randomly flip, None to disable + do_dummy_2d: use pseudo-2D augmentation for anisotropic data + deep_supervision_scales: scales for deep supervision downsampling + is_cascaded: enable cascade training transforms + foreground_labels: labels for cascade one-hot encoding + regions: region definitions for region-based training + ignore_label: label to treat as ignore + + Returns: + Composed MONAI transforms + """ + image_key, label_key = keys + transforms: list = [] + + # Pseudo-2D for anisotropic data + if do_dummy_2d: + transforms.append(Convert3DTo2Dd(keys=[image_key, label_key])) + + # Spatial transforms (rotation, scaling) - using MONAI RandAffine + transforms.append( + RandAffined( + keys=[image_key, label_key], + prob=0.2, + rotate_range=[rotation] * 3 if len(patch_size) == 3 else [rotation], + scale_range=[(s - 1, s - 1) for s in scale], # MONAI uses additive range + mode=["bilinear", "nearest"], + padding_mode="zeros", + ) + ) + + if do_dummy_2d: + transforms.append(Convert2DTo3Dd(keys=[image_key, label_key])) + + # Intensity transforms - MONAI versions + transforms.append(RandGaussianNoised(keys=[image_key], prob=0.1, mean=0.0, std=0.1)) + transforms.append(RandGaussianSmoothd(keys=[image_key], prob=0.2, sigma_x=(0.5, 1.0), sigma_y=(0.5, 1.0), sigma_z=(0.5, 1.0))) + transforms.append(RandScaleIntensityd(keys=[image_key], prob=0.15, factors=0.25)) # multiplicative brightness + transforms.append(RandAdjustContrastd(keys=[image_key], prob=0.15, gamma=(0.75, 1.25))) + transforms.append(RandSimulateLowResolutiond(keys=[image_key], prob=0.25, zoom_range=(0.5, 1.0))) + + # Gamma transforms (nnUNet-specific with invert option) + transforms.append(RandGammad(keys=[image_key], prob=0.1, gamma=(0.7, 1.5), p_invert=1.0, p_retain_stats=1.0)) + transforms.append(RandGammad(keys=[image_key], prob=0.3, gamma=(0.7, 1.5), p_invert=0.0, p_retain_stats=1.0)) + + # Mirror/Flip + if mirror_axes: + for axis in mirror_axes: + transforms.append(RandFlipd(keys=[image_key, label_key], prob=0.5, spatial_axis=axis)) + + # Remove invalid labels + transforms.append(RemoveLabeld(keys=[label_key], label=-1, set_to=0)) + + # Cascade training + if is_cascaded and foreground_labels: + transforms.append( + MoveSegAsOneHotToDatad( + image_key=image_key, + seg_key=label_key, + source_channel=1, + labels=foreground_labels, + remove_from_seg=True, + ) + ) + cascade_channels = list(range(-len(foreground_labels), 0)) + transforms.append( + RandApplyRandomBinaryOperatord(keys=[image_key], channels=cascade_channels, prob=0.4, strel_size=(1, 8)) + ) + transforms.append( + RandRemoveConnectedComponentd(keys=[image_key], channels=cascade_channels, prob=0.2, max_coverage=0.15) + ) + + # Region-based training + if regions: + region_list = list(regions) + 
([ignore_label] if ignore_label is not None else []) + transforms.append(ConvertSegToRegionsd(keys=[label_key], regions=region_list, channel=0)) + + # Deep supervision + if deep_supervision_scales: + transforms.append(DownsampleSegForDSd(keys=[label_key], scales=deep_supervision_scales)) + + return Compose(transforms) + + +def validation_transforms( + keys: tuple[str, str] = ("image", "label"), + deep_supervision_scales: Sequence[float] | None = None, + is_cascaded: bool = False, + foreground_labels: Sequence[int] | None = None, + regions: Sequence[int | Sequence[int]] | None = None, + ignore_label: int | None = None, +) -> Compose: + """ + Create nnUNet-style validation transforms using MONAI infrastructure. + + Args: + keys: (image_key, label_key) for dictionary transforms + deep_supervision_scales: scales for deep supervision downsampling + is_cascaded: enable cascade training transforms + foreground_labels: labels for cascade one-hot encoding + regions: region definitions for region-based training + ignore_label: label to treat as ignore + + Returns: + Composed MONAI transforms + """ + image_key, label_key = keys + transforms: list = [] + + transforms.append(RemoveLabeld(keys=[label_key], label=-1, set_to=0)) + + if is_cascaded and foreground_labels: + transforms.append( + MoveSegAsOneHotToDatad( + image_key=image_key, + seg_key=label_key, + source_channel=1, + labels=foreground_labels, + remove_from_seg=True, + ) + ) + + if regions: + region_list = list(regions) + ([ignore_label] if ignore_label is not None else []) + transforms.append(ConvertSegToRegionsd(keys=[label_key], regions=region_list, channel=0)) + + if deep_supervision_scales: + transforms.append(DownsampleSegForDSd(keys=[label_key], scales=deep_supervision_scales)) + + return Compose(transforms) + + +# ============================================================================= +# Re-export MONAI transforms for convenience +# ============================================================================= +__all__ = [ + # MONAI re-exports + "Compose", + "OneOf", + "RandAffined", + "RandFlipd", + "RandGaussianNoised", + "RandGaussianSmoothd", + "RandScaleIntensityd", + "RandAdjustContrastd", + "RandSimulateLowResolutiond", + # nnUNet-specific + "BGContrast", + "DownsampleSegForDS", + "DownsampleSegForDSd", + "Convert3DTo2D", + "Convert3DTo2Dd", + "Convert2DTo3D", + "Convert2DTo3Dd", + "ConvertSegToRegions", + "ConvertSegToRegionsd", + "MoveSegAsOneHotToData", + "MoveSegAsOneHotToDatad", + "RemoveLabel", + "RemoveLabeld", + "RandGammad", + "RandApplyRandomBinaryOperator", + "RandApplyRandomBinaryOperatord", + "RandRemoveConnectedComponent", + "RandRemoveConnectedComponentd", + # Factory functions + "training_transforms", + "validation_transforms", +] diff --git a/benchmark/unet.py b/benchmark/unet.py new file mode 100644 index 0000000..6d8ae77 --- /dev/null +++ b/benchmark/unet.py @@ -0,0 +1,36 @@ +from typing import override + +from monai.networks.nets import DynUNet +from torch import nn + +from mipcandy import SegmentationTrainer, SlidingTrainer, AmbiguousShape + + +class UNetTrainer(SegmentationTrainer): + num_classes = 5 + deep_supervision = False + deep_supervision_scales = [1, 2, 4] + include_background: bool = False + + @override + def build_network(self, example_shape: AmbiguousShape) -> nn.Module: + kernels = [[3, 3, 3]] * 6 + strides = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]] + filters = [32, 64, 128, 256, 512, 1024] + return DynUNet( + spatial_dims=3, + in_channels=example_shape[0], + 
out_channels=self.num_classes, + kernel_size=kernels, + strides=strides, + upsample_kernel_size=strides[1:], + filters=filters, + norm_name="INSTANCE", + deep_supervision=self.deep_supervision, + deep_supr_num=2, + res_block=True + ) + + +class UNetSlidingTrainer(UNetTrainer, SlidingTrainer): + pass diff --git a/mipcandy/__init__.py b/mipcandy/__init__.py index 1c2ce01..ff93954 100644 --- a/mipcandy/__init__.py +++ b/mipcandy/__init__.py @@ -11,6 +11,7 @@ dice_similarity_coefficient_multiclass, soft_dice_coefficient, accuracy_binary, accuracy_multiclass, \ precision_binary, precision_multiclass, recall_binary, recall_multiclass, iou_binary, iou_multiclass from mipcandy.presets import * +from mipcandy.profiler import ProfilerFrame, Profiler from mipcandy.run import config from mipcandy.sanity_check import num_trainable_params, model_complexity_info, SanityCheckResult, sanity_check from mipcandy.training import TrainerToolbox, Trainer diff --git a/mipcandy/common/module/__init__.py b/mipcandy/common/module/__init__.py index 4217de7..6512ece 100644 --- a/mipcandy/common/module/__init__.py +++ b/mipcandy/common/module/__init__.py @@ -1,2 +1,3 @@ from mipcandy.common.module.conv import ConvBlock2d, ConvBlock3d, WSConv2d, WSConv3d -from mipcandy.common.module.preprocess import Pad2d, Pad3d, Restore2d, Restore3d, Normalize, ColorizeLabel +from mipcandy.common.module.preprocess import Pad2d, Pad3d, Restore2d, Restore3d, PadTo, Normalize, CTNormalize, \ + ColorizeLabel diff --git a/mipcandy/common/module/preprocess.py b/mipcandy/common/module/preprocess.py index 284b895..01fabf6 100644 --- a/mipcandy/common/module/preprocess.py +++ b/mipcandy/common/module/preprocess.py @@ -4,7 +4,7 @@ import torch from torch import nn -from mipcandy.types import Colormap, Shape2d, Shape3d, Paddings2d, Paddings3d, Paddings +from mipcandy.types import Colormap, Shape2d, Shape3d, Shape, Paddings2d, Paddings3d, Paddings def reverse_paddings(paddings: Paddings) -> Paddings: @@ -124,17 +124,30 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x[:, pad_d0: d - pad_d1, pad_h0: h - pad_h1, pad_w0: w - pad_w1] +class PadTo(Pad): + def __init__(self, min_shape: Shape, *, value: int = 0, mode: str = "constant", batch: bool = True) -> None: + super().__init__(value=value, mode=mode, batch=batch) + self._min_shape: Shape = min_shape + self._pad2d: Pad2d = Pad2d(min_shape[0], value=value, mode=mode, batch=batch) + self._pad3d: Pad3d = Pad3d(min_shape[0], value=value, mode=mode, batch=batch) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return (self._pad2d(x) if x.ndim == (4 if self.batch else 3) else self._pad3d(x)) if any( + x.shape[i + (2 if self.batch else 1)] < min_size for i, min_size in enumerate(self._min_shape)) else x + + class Normalize(nn.Module): def __init__(self, *, domain: tuple[float | None, float | None] = (0, None), strict: bool = False, - method: Literal["linear", "intercept", "cut"] = "linear") -> None: + method: Literal["linear", "intercept", "cut", "zscore"] = "linear") -> None: super().__init__() self._domain: tuple[float | None, float | None] = domain self._strict: bool = strict - self._method: Literal["linear", "intercept", "cut"] = method + self._method: Literal["linear", "intercept", "cut", "zscore"] = method + self.requires_grad_(False) def forward(self, x: torch.Tensor) -> torch.Tensor: left, right = self._domain - if left is None and right is None: + if left is None and right is None and self._method != "zscore": return x r_l, r_r = x.min(), x.max() match self._method: @@ 
-165,6 +178,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if right is not None: x = x.clamp(max=right) return x + case "zscore": + if left is not None or right is not None: + raise ValueError("Method \"zscore\" cannot have fixed ends") + return (x - x.mean()) / max(x.std(), torch.tensor(1e-8, device=x.device)) + + +class CTNormalize(nn.Module): + def __init__(self, mean_intensity: float, std_intensity: float, lower_bound: float, upper_bound: float) -> None: + super().__init__() + self._mean_intensity: float = mean_intensity + self._std_intensity: float = std_intensity + self._lower_bound: float = lower_bound + self._upper_bound: float = upper_bound + self.requires_grad_(False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return (x.clip(self._lower_bound, self._upper_bound) - self._mean_intensity) / max(self._std_intensity, 1e-8) class ColorizeLabel(nn.Module): @@ -178,6 +208,7 @@ def __init__(self, *, colormap: Colormap | None = None, batch: bool = True) -> N colormap.append([r * 32, g * 32, 255 - b * 32]) self._colormap: torch.Tensor = torch.tensor(colormap) self._batch: bool = batch + self.requires_grad_(False) def forward(self, x: torch.Tensor) -> torch.Tensor: if not self._batch: diff --git a/mipcandy/common/optim/__init__.py b/mipcandy/common/optim/__init__.py index f7eb905..75f46c2 100644 --- a/mipcandy/common/optim/__init__.py +++ b/mipcandy/common/optim/__init__.py @@ -1,2 +1,2 @@ -from mipcandy.common.optim.loss import FocalBCEWithLogits, DiceBCELossWithLogits -from mipcandy.common.optim.lr_scheduler import AbsoluteLinearLR +from mipcandy.common.optim.loss import FocalBCEWithLogits, DiceCELossWithLogits, DiceBCELossWithLogits +from mipcandy.common.optim.lr_scheduler import AbsoluteLinearLR, PolyLRScheduler diff --git a/mipcandy/common/optim/loss.py b/mipcandy/common/optim/loss.py index 6308ab6..2a8faf3 100644 --- a/mipcandy/common/optim/loss.py +++ b/mipcandy/common/optim/loss.py @@ -3,8 +3,8 @@ import torch from torch import nn -from mipcandy.data import convert_ids_to_logits -from mipcandy.metrics import do_reduction, soft_dice_coefficient +from mipcandy.data import convert_ids_to_logits, convert_logits_to_ids +from mipcandy.metrics import do_reduction, soft_dice_coefficient, dice_similarity_coefficient_binary class FocalBCEWithLogits(nn.Module): @@ -24,25 +24,79 @@ def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: return do_reduction(loss, self.reduction) -class DiceBCELossWithLogits(nn.Module): - def __init__(self, num_classes: int, *, lambda_bce: float = .5, lambda_soft_dice: float = 1, - smooth: float = 1e-5, include_background: bool = True) -> None: +class _Loss(nn.Module): + def __init__(self, include_background: bool) -> None: super().__init__() + self.validation_mode: bool = False + self.include_background: bool = include_background + + def forward(self, masks: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]: + if not self.validation_mode: + return self._forward(masks, labels) + with torch.no_grad(): + c, metrics = self._forward(masks, labels) + masks = convert_logits_to_ids(masks) + dice = 0 + for i in range(0 if self.include_background else 1, self.num_classes): + class_dice = dice_similarity_coefficient_binary(masks == i, labels == i).item() + dice += class_dice + metrics[f"dice {i}"] = class_dice + metrics["dice"] = dice / (self.num_classes - (0 if self.include_background else 1)) + return c, metrics + + +class _SegmentationLoss(_Loss): + def __init__(self, num_classes: int, 
include_background: bool) -> None: + super().__init__(include_background) self.num_classes: int = num_classes + + def logitfy(self, labels: torch.Tensor) -> torch.Tensor: + with torch.no_grad(): + if self.num_classes != 1 and labels.shape[1] == 1: + d = labels.ndim - 2 + if d not in (1, 2, 3): + raise ValueError(f"Expected labels to be 1D, 2D, or 3D, got {d} spatial dimensions") + return convert_ids_to_logits(labels.int(), d, self.num_classes) + return labels.float() + + +class DiceCELossWithLogits(_SegmentationLoss): + def __init__(self, num_classes: int, *, lambda_ce: float = 1, lambda_soft_dice: float = 1, + smooth: float = 1e-5, include_background: bool = True) -> None: + super().__init__(num_classes, include_background) + self.lambda_ce: float = lambda_ce + self.lambda_soft_dice: float = lambda_soft_dice + self.smooth: float = smooth + + def _forward(self, masks: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]: + labels = self.logitfy(labels) + ce = nn.functional.cross_entropy(masks, labels) + masks = masks.softmax(1) + if not self.include_background: + masks = masks[:, 1:] + labels = labels[:, 1:] + soft_dice = soft_dice_coefficient(masks, labels, smooth=self.smooth) + metrics = {"soft dice": soft_dice.item(), "ce loss": ce.item()} + c = self.lambda_ce * ce + self.lambda_soft_dice * (1 - soft_dice) + return c, metrics + + +class DiceBCELossWithLogits(_SegmentationLoss): + def __init__(self, *, lambda_bce: float = 1, lambda_soft_dice: float = 1, + smooth: float = 1e-5, include_background: bool = True) -> None: + super().__init__(1, include_background) self.lambda_bce: float = lambda_bce self.lambda_soft_dice: float = lambda_soft_dice self.smooth: float = smooth - self.include_background: bool = include_background - def forward(self, masks: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]: - if self.num_classes != 1 and labels.shape[1] == 1: - d = labels.ndim - 2 - if d not in (1, 2, 3): - raise ValueError(f"Expected labels to be 1D, 2D, or 3D, got {d} spatial dimensions") - labels = convert_ids_to_logits(labels.int(), d, self.num_classes) - labels = labels.float() - bce = nn.functional.binary_cross_entropy_with_logits(masks, labels) - masks = masks.sigmoid() - soft_dice = soft_dice_coefficient(masks, labels, smooth=self.smooth, include_background=self.include_background) + def _forward(self, masks: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]: + labels = self.logitfy(labels) + if not self.include_background: + masks = masks[:, 1:] + labels = labels[:, 1:] + bce = nn.functional.binary_cross_entropy(masks, labels) + masks.sigmoid_() + soft_dice = soft_dice_coefficient(masks, labels, smooth=self.smooth) + metrics = {"soft dice": soft_dice.item(), "bce loss": bce.item()} c = self.lambda_bce * bce + self.lambda_soft_dice * (1 - soft_dice) - return c, {"soft dice": soft_dice.item(), "bce loss": bce.item()} + return c, metrics diff --git a/mipcandy/common/optim/lr_scheduler.py b/mipcandy/common/optim/lr_scheduler.py index 1d4d835..dcb8b63 100644 --- a/mipcandy/common/optim/lr_scheduler.py +++ b/mipcandy/common/optim/lr_scheduler.py @@ -18,13 +18,13 @@ def __init__(self, optimizer: optim.Optimizer, k: float, b: float, *, min_lr: fl self._restart_step: int = 0 super().__init__(optimizer, last_epoch) - def _interp(self, step: int) -> float: - step -= self._restart_step - r = self._k * step + self._b + def _interp(self, epoch: int) -> float: + epoch -= self._restart_step + r = self._k * epoch + self._b if r 
< self._min_lr: if self._restart: - self._restart_step = step - return self._interp(step) + self._restart_step = epoch + return self._interp(epoch) return self._min_lr return r @@ -32,3 +32,20 @@ def _interp(self, step: int) -> float: def get_lr(self) -> list[float]: target = self._interp(self.last_epoch) return [target for _ in self.optimizer.param_groups] + + +class PolyLRScheduler(optim.lr_scheduler.LRScheduler): + def __init__(self, optimizer: optim.Optimizer, initial_lr: float, max_steps: int, *, exponent: float = .9, + last_epoch: int = -1) -> None: + self._initial_lr: float = initial_lr + self._max_steps: int = max_steps + self._exponent: float = exponent + super().__init__(optimizer, last_epoch) + + def _interp(self, epoch: int) -> float: + return self._initial_lr * (1 - epoch / self._max_steps) ** self._exponent + + @override + def get_lr(self) -> list[float]: + target = self._interp(self.last_epoch) + return [target for _ in self.optimizer.param_groups] diff --git a/mipcandy/data/__init__.py b/mipcandy/data/__init__.py index 3b9de74..68bfd02 100644 --- a/mipcandy/data/__init__.py +++ b/mipcandy/data/__init__.py @@ -5,7 +5,7 @@ from mipcandy.data.geometric import ensure_num_dimensions, orthographic_views, aggregate_orthographic_views, crop from mipcandy.data.inspection import InspectionAnnotation, InspectionAnnotations, load_inspection_annotations, \ inspect, ROIDataset, RandomROIDataset -from mipcandy.data.io import fast_save, fast_load, resample_to_isotropic, load_image, save_image +from mipcandy.data.io import fast_save, fast_load, resample_to_isotropic, load_image, save_image, empty_cache from mipcandy.data.sliding_window import do_sliding_window, revert_sliding_window, slide_dataset, \ UnsupervisedSWDataset, SupervisedSWDataset from mipcandy.data.transform import JointTransform, MONAITransform diff --git a/mipcandy/data/convertion.py b/mipcandy/data/convertion.py index dc5683d..f6c80f2 100644 --- a/mipcandy/data/convertion.py +++ b/mipcandy/data/convertion.py @@ -19,7 +19,7 @@ def convert_ids_to_logits(ids: torch.Tensor, d: Literal[1, 2, 3], num_classes: i def convert_logits_to_ids(logits: torch.Tensor, *, channel_dim: int = 1) -> torch.Tensor: - return logits.max(channel_dim, keepdim=True).indices.int() + return logits.argmax(channel_dim, keepdim=True) def auto_convert(image: torch.Tensor) -> torch.Tensor: diff --git a/mipcandy/data/dataset.py b/mipcandy/data/dataset.py index f21640f..b8fdac4 100644 --- a/mipcandy/data/dataset.py +++ b/mipcandy/data/dataset.py @@ -1,6 +1,7 @@ from abc import ABCMeta, abstractmethod from json import dump -from os import PathLike, listdir, makedirs, rmdir +from math import log10 +from os import PathLike, listdir, makedirs from os.path import exists from random import choices from shutil import copy2 @@ -8,6 +9,7 @@ import torch from pandas import DataFrame +from torch import nn from torch.utils.data import Dataset from mipcandy.data.io import fast_save, fast_load, load_image @@ -66,6 +68,8 @@ def load(self, idx: int) -> T: @override def __getitem__(self, idx: int) -> T: + if idx >= len(self): + raise IndexError(f"Index {idx} out of range [0, {len(self)})") return self.load(idx) @@ -80,7 +84,8 @@ class UnsupervisedDataset(_AbstractDataset[torch.Tensor], Generic[D], metaclass= def __init__(self, images: D, *, transform: Transform | None = None, device: Device = "cpu") -> None: super().__init__(device) self._images: D = images - self._transform: Transform | None = transform.to(device) if transform else None + self._transform: Transform | None = 
None + self.set_transform(transform) @override def __len__(self) -> int: @@ -88,11 +93,17 @@ def __len__(self) -> int: @override def __getitem__(self, idx: int) -> torch.Tensor: - item = super().__getitem__(idx).to(self._device) + item = super().__getitem__(idx).to(self._device, non_blocking=True) if self._transform: item = self._transform(item) return item.as_tensor() if hasattr(item, "as_tensor") else item + def transform(self) -> Transform | None: + return self._transform + + def set_transform(self, transform: Transform | None) -> None: + self._transform = transform.to(self._device) if isinstance(transform, nn.Module) else transform + class SupervisedDataset(_AbstractDataset[tuple[torch.Tensor, torch.Tensor]], Generic[D], metaclass=ABCMeta): """ @@ -106,40 +117,97 @@ def __init__(self, images: D, labels: D, *, transform: JointTransform | None = N raise ValueError(f"Unmatched number of images {len(images)} and labels {len(labels)}") self._images: D = images self._labels: D = labels - self._transform: JointTransform | None = transform.to(device) if transform else None + self._transform: JointTransform | None = None + self.set_transform(transform) + self._preloaded: str = "" + self._nd: int = int(log10(len(self))) + 1 @override def __len__(self) -> int: return len(self._images) + @abstractmethod + def load_image(self, idx: int) -> torch.Tensor: + raise NotImplementedError + + @abstractmethod + def load_label(self, idx: int) -> torch.Tensor: + raise NotImplementedError + + @override + def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: + return self.load_image(idx), self.load_label(idx) + @override def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: - image, label = super().__getitem__(idx) - image, label = image.to(self._device), label.to(self._device) + if self._preloaded: + if idx >= len(self): + raise IndexError(f"Index {idx} out of range [0, {len(self)})") + idx = str(idx).zfill(self._nd) + image, label = fast_load(f"{self._preloaded}/images/{idx}.pt"), fast_load( + f"{self._preloaded}/labels/{idx}.pt") + else: + image, label = super().__getitem__(idx) + image, label = image.to(self._device, non_blocking=True), label.to(self._device, non_blocking=True) if self._transform: image, label = self._transform(image, label) return image.as_tensor() if hasattr(image, "as_tensor") else image, label.as_tensor() if hasattr( label, "as_tensor") else label + def image(self, idx: int) -> torch.Tensor: + return self.load_image(idx) + + def label(self, idx: int) -> torch.Tensor: + return self.load_label(idx) + + def transform(self) -> JointTransform | None: + return self._transform + + def set_transform(self, transform: JointTransform | None) -> None: + self._transform = transform.to(self._device) if transform else None + + def _construct_new(self, images: D, labels: D) -> Self: + new = self.construct_new(images, labels) + new._preloaded = self._preloaded + new._nd = self._nd + return new + @abstractmethod def construct_new(self, images: D, labels: D) -> Self: raise NotImplementedError + def preload(self, output_folder: str | PathLike[str], *, do_transform: bool = False) -> None: + if self._preloaded: + return + images_path = f"{output_folder}/images" + labels_path = f"{output_folder}/labels" + if not exists(images_path) and not exists(labels_path): + makedirs(images_path) + makedirs(labels_path) + for idx in range(len(self)): + image, label = self[idx] if do_transform else self.load(idx) + idx = str(idx).zfill(self._nd) + fast_save(image, f"{images_path}/{idx}.pt") + 
fast_save(label, f"{labels_path}/{idx}.pt") + if do_transform: + self._transform = None + self._preloaded = output_folder + def fold(self, *, fold: Literal[0, 1, 2, 3, 4, "all"] = "all", picker: type[KFPicker] = OrderedKFPicker) -> tuple[ Self, Self]: - indexes = picker.pick(len(self), fold) + indices = picker.pick(len(self), fold) images_train = [] labels_train = [] images_val = [] labels_val = [] for i in range(len(self)): - if i in indexes: + if i in indices: images_val.append(self._images[i]) labels_val.append(self._labels[i]) else: images_train.append(self._images[i]) labels_train.append(self._labels[i]) - return self.construct_new(images_train, labels_train), self.construct_new(images_val, labels_val) + return self._construct_new(images_train, labels_train), self._construct_new(images_val, labels_val) class DatasetFromMemory(UnsupervisedDataset[Sequence[torch.Tensor]]): @@ -158,8 +226,12 @@ def __init__(self, images: UnsupervisedDataset, labels: UnsupervisedDataset, *, super().__init__(images, labels, transform=transform, device=device) @override - def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: - return self._images[idx], self._labels[idx] + def load_image(self, idx: int) -> torch.Tensor: + return self._images[idx] + + @override + def load_label(self, idx: int) -> torch.Tensor: + return self._labels[idx] @override def construct_new(self, images: UnsupervisedDataset, labels: UnsupervisedDataset) -> Self: @@ -213,14 +285,15 @@ def save_paths(self, to: str | PathLike[str]) -> None: class SimpleDataset(PathBasedUnsupervisedDataset): - def __init__(self, folder: str | PathLike[str], *, transform: Transform | None = None, + def __init__(self, folder: str | PathLike[str], is_label: bool, *, transform: Transform | None = None, device: Device = "cpu") -> None: super().__init__(sorted(listdir(folder)), transform=transform, device=device) self._folder: str = folder + self._is_label: bool = is_label @override def load(self, idx: int) -> torch.Tensor: - return self.do_load(f"{self._folder}/{self._images[idx]}", device=self._device) + return self.do_load(f"{self._folder}/{self._images[idx]}", is_label=self._is_label, device=self._device) class PathBasedSupervisedDataset(SupervisedDataset[list[str]], metaclass=ABCMeta): @@ -277,24 +350,20 @@ def _create_subset(folder: str) -> None: makedirs(folder, exist_ok=True) @override - def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: - if self._split.endswith("Preloaded"): - return ( - TensorLoader.do_load(f"{self._folder}/images{self._split}/{self._images[idx]}.pt", device=self._device), - TensorLoader.do_load(f"{self._folder}/labels{self._split}/{self._labels[idx]}.pt", is_label=True, - device=self._device) - ) - image = torch.cat([self.do_load( + def load_image(self, idx: int) -> torch.Tensor: + return torch.cat([self.do_load( f"{self._folder}/images{self._split}/{path}", align_spacing=self._align_spacing, device=self._device ) for path in self._multimodal_images[idx]]) if self._multimodal_images else self.do_load( f"{self._folder}/images{self._split}/{self._images[idx]}", align_spacing=self._align_spacing, device=self._device ) - label = self.do_load( + + @override + def load_label(self, idx: int) -> torch.Tensor: + return self.do_load( f"{self._folder}/labels{self._split}/{self._labels[idx]}", is_label=True, align_spacing=self._align_spacing, device=self._device ) - return image, label def save(self, split: str | Literal["Tr", "Ts"], *, target_folder: str | PathLike[str] | None = None) -> None: target_base = target_folder if 
target_folder else self._folder @@ -308,20 +377,6 @@ def save(self, split: str | Literal["Tr", "Ts"], *, target_folder: str | PathLik self._split = split self._folded = False - def preload(self) -> None: - images_path = f"{self._folder}/images{self._split}Preloaded" - labels_path = f"{self._folder}/labels{self._split}Preloaded" - if not exists(images_path) or not exists(labels_path): - rmdir(images_path) - rmdir(labels_path) - makedirs(images_path) - makedirs(images_path) - for idx in range(len(self)): - image, label = self.load(idx) - fast_save(image, f"{images_path}/{self._images[idx]}.pt") - fast_save(label, f"{labels_path}/{self._labels[idx]}.pt") - self._split += "Preloaded" - @override def construct_new(self, images: list[str], labels: list[str]) -> Self: if self._folded: @@ -334,22 +389,39 @@ def construct_new(self, images: list[str], labels: list[str]) -> Self: return new -class BinarizedDataset(SupervisedDataset[D]): - def __init__(self, base: SupervisedDataset[D], positive_ids: tuple[int, ...], *, +class BinarizedDataset(SupervisedDataset[tuple[None]]): + def __init__(self, base: SupervisedDataset, positive_ids: tuple[int, ...], *, transform: JointTransform | None = None, device: Device = "cpu") -> None: - super().__init__(base._images, base._labels, transform=transform, device=device) - self._base: SupervisedDataset[D] = base + super().__init__((None,), (None,), transform=transform, device=device) + self._base: SupervisedDataset = base self._positive_ids: tuple[int, ...] = positive_ids @override - def construct_new(self, images: D, labels: D) -> Self: + def __len__(self) -> int: + return len(self._base) + + @override + def construct_new(self, images: tuple[None], labels: tuple[None]) -> Self: raise NotImplementedError @override - def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: - image, label = self._base.load(idx) + def load_image(self, idx: int) -> torch.Tensor: + return self._base.load_image(idx) + + @override + def load_label(self, idx: int) -> torch.Tensor: + label = self._base.load_label(idx) for pid in self._positive_ids: label[label == pid] = -1 label[label > 0] = 0 label[label == -1] = 1 - return image, label + return label + + @override + def fold(self, *, fold: Literal[0, 1, 2, 3, 4, "all"] = "all", picker: type[KFPicker] = OrderedKFPicker) -> tuple[ + Self, Self]: + train, val = self._base.fold(fold=fold, picker=picker) + return ( + self.__class__(train, self._positive_ids, transform=self._transform, device=self._device), + self.__class__(val, self._positive_ids, transform=self._transform, device=self._device) + ) diff --git a/mipcandy/data/inspection.py b/mipcandy/data/inspection.py index 7495d4a..9df312f 100644 --- a/mipcandy/data/inspection.py +++ b/mipcandy/data/inspection.py @@ -1,7 +1,9 @@ from dataclasses import dataclass, asdict from json import dump, load +from math import ceil from os import PathLike -from typing import Sequence, override, Callable, Self, Any +from random import randint, choice +from typing import Sequence, override, Callable, Self, Any, Literal import numpy as np import torch @@ -11,8 +13,7 @@ from mipcandy.data.dataset import SupervisedDataset from mipcandy.data.geometric import crop -from mipcandy.layer import HasDevice -from mipcandy.types import Device, Shape, AmbiguousShape +from mipcandy.types import Shape, AmbiguousShape def format_bbox(bbox: Sequence[int]) -> tuple[int, int, int, int] | tuple[int, int, int, int, int, int]: @@ -28,7 +29,11 @@ def format_bbox(bbox: Sequence[int]) -> tuple[int, int, int, int] | 
tuple[int, i class InspectionAnnotation(object): shape: AmbiguousShape foreground_bbox: tuple[int, int, int, int] | tuple[int, int, int, int, int, int] - ids: tuple[int, ...] + class_ids: tuple[int, ...] + class_counts: dict[int, int] + class_bboxes: dict[int, tuple[int, int, int, int] | tuple[int, int, int, int, int, int]] + class_locations: dict[int, tuple[tuple[int, int] | tuple[int, int, int], ...]] + spacing: Shape | None = None def foreground_shape(self) -> Shape: r = (self.foreground_bbox[1] - self.foreground_bbox[0], self.foreground_bbox[3] - self.foreground_bbox[2]) @@ -39,21 +44,15 @@ def center_of_foreground(self) -> tuple[int, int] | tuple[int, int, int]: round((self.foreground_bbox[3] + self.foreground_bbox[2]) * .5)) return r if len(self.shape) == 2 else r + (round((self.foreground_bbox[5] + self.foreground_bbox[4]) * .5),) - def to_dict(self) -> dict[str, tuple[int, ...]]: - return asdict(self) - -class InspectionAnnotations(HasDevice, Sequence[InspectionAnnotation]): - def __init__(self, dataset: SupervisedDataset, background: int, *annotations: InspectionAnnotation, - device: Device = "cpu") -> None: - super().__init__(device) +class InspectionAnnotations(Sequence[InspectionAnnotation]): + def __init__(self, dataset: SupervisedDataset, background: int, *annotations: InspectionAnnotation) -> None: self._dataset: SupervisedDataset = dataset self._background: int = background self._annotations: tuple[InspectionAnnotation, ...] = annotations self._shapes: tuple[AmbiguousShape | None, AmbiguousShape, AmbiguousShape] | None = None self._foreground_shapes: tuple[AmbiguousShape | None, AmbiguousShape, AmbiguousShape] | None = None self._statistical_foreground_shape: Shape | None = None - self._foreground_heatmap: torch.Tensor | None = None self._center_of_foregrounds: tuple[int, int] | tuple[int, int, int] | None = None self._foreground_offsets: tuple[int, int] | tuple[int, int, int] | None = None self._roi_shape: Shape | None = None @@ -77,7 +76,7 @@ def __len__(self) -> int: def save(self, path: str | PathLike[str]) -> None: with open(path, "w") as f: - dump({"background": self._background, "annotations": [a.to_dict() for a in self._annotations]}, f) + dump({"background": self._background, "annotations": [asdict(a) for a in self._annotations]}, f) def _get_shapes(self, get_shape: Callable[[InspectionAnnotation], AmbiguousShape]) -> tuple[ AmbiguousShape | None, AmbiguousShape, AmbiguousShape]: @@ -129,11 +128,9 @@ def crop_foreground(self, i: int, *, expand_ratio: float = 1) -> tuple[torch.Ten return crop(image.unsqueeze(0), bbox).squeeze(0), crop(label.unsqueeze(0), bbox).squeeze(0) def foreground_heatmap(self) -> torch.Tensor: - if self._foreground_heatmap: - return self._foreground_heatmap depths, heights, widths = self.foreground_shapes() max_shape = (max(depths), max(heights), max(widths)) if depths else (max(heights), max(widths)) - accumulated_label = torch.zeros((1, *max_shape), device=self._device) + accumulated_label = torch.zeros((1, *max_shape), device=self._dataset.device()) for i, (_, label) in enumerate(self._dataset): annotation = self._annotations[i] paddings = [0, 0, 0, 0] @@ -147,8 +144,7 @@ def foreground_heatmap(self) -> torch.Tensor: accumulated_label += nn.functional.pad( crop((label != self._background).unsqueeze(0), annotation.foreground_bbox), paddings ).squeeze(0) - self._foreground_heatmap = accumulated_label.squeeze(0) - return self._foreground_heatmap + return accumulated_label.squeeze(0).detach() def center_of_foregrounds(self) -> tuple[int, int] | 
tuple[int, int, int]: if self._center_of_foregrounds: @@ -177,29 +173,34 @@ def set_roi_shape(self, roi_shape: Shape | None) -> None: depths, heights, widths = self.shapes() if depths: if roi_shape[0] > min(depths) or roi_shape[1] > min(heights) or roi_shape[2] > min(widths): - raise ValueError(f"ROI shape {roi_shape} exceeds minimum image shape ({min(depths)}, {min(heights)}, {min(widths)})") + raise ValueError( + f"ROI shape {roi_shape} exceeds minimum image shape ({min(depths)}, {min(heights)}, {min(widths)})") else: if roi_shape[0] > min(heights) or roi_shape[1] > min(widths): - raise ValueError(f"ROI shape {roi_shape} exceeds minimum image shape ({min(heights)}, {min(widths)})") + raise ValueError( + f"ROI shape {roi_shape} exceeds minimum image shape ({min(heights)}, {min(widths)})") self._roi_shape = roi_shape - def roi_shape(self, *, percentile: float = .95) -> Shape: + def roi_shape(self, *, clamp: bool = True, percentile: float = .95) -> Shape: if self._roi_shape: return self._roi_shape sfs = self.statistical_foreground_shape(percentile=percentile) - if len(sfs) == 2: - sfs = (None, *sfs) - depths, heights, widths = self.shapes() - roi_shape = (min(min(heights), sfs[1]), min(min(widths), sfs[2])) - if depths: - roi_shape = (min(min(depths), sfs[0]),) + roi_shape - self._roi_shape = roi_shape + if clamp: + if len(sfs) == 2: + sfs = (None, *sfs) + depths, heights, widths = self.shapes() + roi_shape = (min(min(heights), sfs[1]), min(min(widths), sfs[2])) + if depths: + roi_shape = (min(min(depths), sfs[0]),) + roi_shape + self._roi_shape = roi_shape + else: + self._roi_shape = sfs return self._roi_shape - def roi(self, i: int, *, percentile: float = .95) -> tuple[int, int, int, int] | tuple[ + def roi(self, i: int, *, clamp: bool = True, percentile: float = .95) -> tuple[int, int, int, int] | tuple[ int, int, int, int, int, int]: annotation = self._annotations[i] - roi_shape = self.roi_shape(percentile=percentile) + roi_shape = self.roi_shape(clamp=clamp, percentile=percentile) offsets = self.center_of_foregrounds_offsets() center = annotation.center_of_foreground() roi = [] @@ -211,9 +212,9 @@ def roi(self, i: int, *, percentile: float = .95) -> tuple[int, int, int, int] | roi.append(position + offset + right) return tuple(roi) - def crop_roi(self, i: int, *, percentile: float = .95) -> tuple[torch.Tensor, torch.Tensor]: + def crop_roi(self, i: int, *, clamp: bool = True, percentile: float = .95) -> tuple[torch.Tensor, torch.Tensor]: image, label = self._dataset[i] - roi = self.roi(i, percentile=percentile) + roi = self.roi(i, clamp=clamp, percentile=percentile) return crop(image.unsqueeze(0), roi).squeeze(0), crop(label.unsqueeze(0), roi).squeeze(0) @@ -221,126 +222,179 @@ def _lists_to_tuples(pairs: Sequence[tuple[str, Any]]) -> dict[str, Any]: return {k: tuple(v) if isinstance(v, list) else v for k, v in pairs} +def _str_indices_to_int_indices(obj: dict[str, Any]) -> dict[int, Any]: + return {int(k): v for k, v in obj.items()} + + +def parse_inspection_annotation(obj: dict[str, Any]) -> InspectionAnnotation: + obj["class_bboxes"] = _str_indices_to_int_indices(obj["class_bboxes"]) + obj["class_locations"] = _str_indices_to_int_indices(obj["class_locations"]) + return InspectionAnnotation(**obj) + + def load_inspection_annotations(path: str | PathLike[str], dataset: SupervisedDataset) -> InspectionAnnotations: with open(path) as f: obj = load(f, object_pairs_hook=_lists_to_tuples) return InspectionAnnotations(dataset, obj["background"], *( - InspectionAnnotation(**row) for row in 
obj["annotations"] + parse_inspection_annotation(row) for row in obj["annotations"] )) -def inspect(dataset: SupervisedDataset, *, background: int = 0, console: Console = Console()) -> InspectionAnnotations: +def bbox_from_indices(indices: torch.Tensor, num_dim: Literal[2, 3]) -> tuple[int, int, int, int]: + mins = indices.min(dim=0)[0].tolist() + maxs = indices.max(dim=0)[0].tolist() + bbox = (mins[1], maxs[1] + 1, mins[2], maxs[2] + 1) + if num_dim == 3: + bbox += (mins[3], maxs[3] + 1) + return bbox + + +def inspect(dataset: SupervisedDataset, *, background: int = 0, max_samples: int = 10000, + console: Console = Console()) -> InspectionAnnotations: r = [] - with Progress(*Progress.get_default_columns(), SpinnerColumn(), console=console) as progress: + with torch.no_grad(), Progress(*Progress.get_default_columns(), SpinnerColumn(), console=console) as progress: task = progress.add_task("Inspecting dataset...", total=len(dataset)) - for _, label in dataset: + for idx in range(len(dataset)): + label = dataset.label(idx).int() progress.update(task, advance=1, description=f"Inspecting dataset {tuple(label.shape)}") + ndim = label.ndim - 1 indices = (label != background).nonzero() - mins = indices.min(dim=0)[0].tolist() - maxs = indices.max(dim=0)[0].tolist() - bbox = (mins[1], maxs[1] + 1, mins[2], maxs[2] + 1) + if len(indices) == 0: + r.append(InspectionAnnotation( + tuple(label.shape[1:]), (0, 0, 0, 0) if ndim == 2 else (0, 0, 0, 0, 0, 0), (), {}, {}, {}) + ) + continue + foreground_bbox = bbox_from_indices(indices, ndim) + class_ids = label.unique().tolist() + if background in class_ids: + class_ids.remove(background) + class_counts = {} + class_bboxes = {} + class_locations = {} + for class_id in [background] + class_ids: + indices = (label == class_id).nonzero() + class_counts[class_id] = len(indices) + class_bboxes[class_id] = bbox_from_indices(indices, ndim) + if len(indices) > max_samples: + target_samples = min(max_samples, len(indices)) + sampled_idx = torch.randperm(len(indices))[:target_samples] + indices = indices[sampled_idx] + class_locations[class_id] = [tuple(coord.tolist()[1:]) for coord in indices] r.append(InspectionAnnotation( - label.shape[1:], bbox if label.ndim == 3 else bbox + (mins[3], maxs[3] + 1), tuple(label.unique()) + tuple(label.shape[1:]), foreground_bbox, tuple(class_ids), class_counts, class_bboxes, class_locations )) - return InspectionAnnotations(dataset, background, *r, device=dataset.device()) + return InspectionAnnotations(dataset, background, *r) class ROIDataset(SupervisedDataset[list[int]]): - def __init__(self, annotations: InspectionAnnotations, *, percentile: float = .95) -> None: - super().__init__(list(range(len(annotations))), list(range(len(annotations)))) + def __init__(self, annotations: InspectionAnnotations, *, clamp: bool = True, percentile: float = .95) -> None: + super().__init__(list(range(len(annotations))), list(range(len(annotations))), + transform=annotations.dataset().transform(), device=annotations.dataset().device()) self._annotations: InspectionAnnotations = annotations + self._clamp: bool = clamp self._percentile: float = percentile @override - def construct_new(self, images: list[torch.Tensor], labels: list[torch.Tensor]) -> Self: - return self.__class__(self._annotations, percentile=self._percentile) + def construct_new(self, images: list[int], labels: list[int]) -> Self: + new = self.__class__(self._annotations, percentile=self._percentile) + new._images = images + new._labels = labels + return new + + @override + def 
load_image(self, idx: int) -> torch.Tensor: + raise NotImplementedError + + @override + def load_label(self, idx: int) -> torch.Tensor: + raise NotImplementedError @override def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: i = self._images[idx] if i != self._labels[idx]: raise ValueError(f"Image {i} and label {self._labels[idx]} indices do not match") - return self._annotations.crop_roi(i, percentile=self._percentile) + with torch.no_grad(): + return self._annotations.crop_roi(i, clamp=self._clamp, percentile=self._percentile) + + +def crop_and_pad(x: torch.Tensor, bbox_lbs: list[int], bbox_ubs: list[int], *, + pad_value: int | float = 0) -> torch.Tensor: + shape = x.shape[1:] + dim = len(shape) + valid_bbox_lbs = [max(0, bbox_lbs[i]) for i in range(dim)] + valid_bbox_ubs = [min(shape[i], bbox_ubs[i]) for i in range(dim)] + slices = tuple([slice(0, x.shape[0])] + [slice(valid_bbox_lbs[i], valid_bbox_ubs[i]) for i in range(dim)]) + cropped = x[slices] + padding = [(-min(0, bbox_lbs[i]), max(bbox_ubs[i] - shape[i], 0)) for i in range(dim)] + padding_torch = [] + for left, right in reversed(padding): + padding_torch.extend([left, right]) + padded = nn.functional.pad(cropped, padding_torch, mode="constant", value=pad_value) + return padded class RandomROIDataset(ROIDataset): - def __init__(self, annotations: InspectionAnnotations, *, percentile: float = .95, - foreground_oversample_percentage: float = .33, min_foreground_samples: int = 500, - max_foreground_samples: int = 10000, min_percent_coverage: float = .01) -> None: - super().__init__(annotations, percentile=percentile) - self._fg_oversample: float = foreground_oversample_percentage - self._min_fg_samples: int = min_foreground_samples - self._max_fg_samples: int = max_foreground_samples - self._min_coverage: float = min_percent_coverage - self._fg_locations_cache: dict[int, tuple[tuple[int, ...], ...] | None] = {} - - def _get_foreground_locations(self, idx: int) -> tuple[tuple[int, ...], ...] 
| None: - if idx not in self._fg_locations_cache: - _, label = self._annotations.dataset()[idx] - indices = (label != self._annotations.background()).nonzero()[:, 1:] - if len(indices) == 0: - self._fg_locations_cache[idx] = None - elif len(indices) <= self._min_fg_samples: - self._fg_locations_cache[idx] = tuple(tuple(coord.tolist()) for coord in indices) - else: - target_samples = min( - self._max_fg_samples, - max(self._min_fg_samples, int(np.ceil(len(indices) * self._min_coverage))) - ) - sampled_idx = torch.randperm(len(indices))[:target_samples] - sampled = indices[sampled_idx] - self._fg_locations_cache[idx] = tuple(tuple(coord.tolist()) for coord in sampled) - return self._fg_locations_cache[idx] - - def _random_roi(self, idx: int) -> tuple[int, int, int, int] | tuple[int, int, int, int, int, int]: - annotation = self._annotations[idx] - roi_shape = self._annotations.roi_shape(percentile=self._percentile) - roi = [] - for dim_size, patch_size in zip(annotation.shape, roi_shape): - left = patch_size // 2 - right = patch_size - left - min_center = left - max_center = dim_size - right - center = torch.randint(min_center, max_center + 1, (1,)).item() - roi.append(center - left) - roi.append(center + right) - return tuple(roi) + def __init__(self, annotations: InspectionAnnotations, batch_size: int, *, num_patches_per_case: int = 1, + oversample_rate: float = .67, clamp: bool = False, percentile: float = .5, + min_factor: int = 16) -> None: + super().__init__(annotations, clamp=clamp, percentile=percentile) + if num_patches_per_case > 1: + images = [idx for idx in self._images for _ in range(num_patches_per_case)]  # repeat each case index once per requested patch + self._images, self._labels = images, images.copy() + self._batch_size: int = batch_size + self._oversample_rate: float = oversample_rate + sfs = self._annotations.statistical_foreground_shape(percentile=self._percentile) + sfs = [ceil(s / min_factor) * min_factor for s in sfs] + self._roi_shape: Shape = (min(sfs[0], 2048), min(sfs[1], 2048)) if len(sfs) == 2 else ( + min(sfs[0], 128), min(sfs[1], 128), min(sfs[2], 128)) + + def convert_idx(self, idx: int) -> int: + idx, idx2 = self._images[idx], self._labels[idx] + if idx != idx2: + raise ValueError(f"Image {idx} and label {idx2} indices do not match") + return idx + + def roi_shape(self, *, roi_shape: Shape | None = None) -> None | Shape: + if not roi_shape: + return self._roi_shape + self._roi_shape = roi_shape - def _foreground_guided_random_roi(self, idx: int) -> tuple[int, int, int, int] | tuple[ - int, int, int, int, int, int]: + @override + def construct_new(self, images: list[int], labels: list[int]) -> Self: + new = self.__class__(self._annotations, self._batch_size, oversample_rate=self._oversample_rate, + clamp=self._clamp, percentile=self._percentile) + new._images = images + new._labels = labels + new._roi_shape = self._roi_shape + return new + + def random_roi(self, idx: int, force_foreground: bool) -> tuple[list[int], list[int]]: + idx = self.convert_idx(idx) annotation = self._annotations[idx] - roi_shape = self._annotations.roi_shape(percentile=self._percentile) - foreground_locations = self._get_foreground_locations(idx) - - if foreground_locations is None or len(foreground_locations) == 0: - return self._random_roi(idx) - - fg_idx = torch.randint(0, len(foreground_locations), (1,)).item() - fg_position = foreground_locations[fg_idx] - - roi = [] - for fg_pos, dim_size, patch_size in zip(fg_position, annotation.shape, roi_shape): - left = patch_size // 2 - right = patch_size - left - center = max(left, min(fg_pos, dim_size - 
right)) - roi.append(center - left) - roi.append(center + right) - return tuple(roi) + roi_shape = self._roi_shape + dim = len(annotation.shape) + need_to_pad = [max(0, roi_shape[i] - annotation.shape[i]) for i in range(dim)] + lbs = [-need_to_pad[i] // 2 for i in range(dim)] + ubs = [annotation.shape[i] + need_to_pad[i] // 2 + need_to_pad[i] % 2 - roi_shape[i] for i in range(dim)] + if force_foreground: + if len(annotation.class_ids) == 0: + bbox_lbs = [randint(lbs[j], ubs[j]) for j in range(dim)] + else: + selected_class = choice(annotation.class_ids) + selected_voxel = choice(annotation.class_locations[selected_class]) + bbox_lbs = [max(lbs[i], selected_voxel[i] - roi_shape[i] // 2) for i in range(dim)] + else: + bbox_lbs = [randint(lbs[i], ubs[i]) for i in range(dim)] + return bbox_lbs, [bbox_lbs[i] + roi_shape[i] for i in range(dim)] - @override - def construct_new(self, images: list[torch.Tensor], labels: list[torch.Tensor]) -> Self: - return self.__class__(self._annotations, percentile=self._percentile, - foreground_oversample_percentage=self._fg_oversample, - min_foreground_samples=self._min_fg_samples, - max_foreground_samples=self._max_fg_samples, - min_percent_coverage=self._min_coverage) + def oversample_foreground(self, idx: int) -> bool: + return idx % self._batch_size >= round(self._batch_size * (1 - self._oversample_rate)) @override def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: - image, label = self._annotations.dataset()[idx] - force_fg = torch.rand(1).item() < self._fg_oversample - if force_fg: - roi = self._foreground_guided_random_roi(idx) - else: - roi = self._random_roi(idx) - return crop(image.unsqueeze(0), roi).squeeze(0), crop(label.unsqueeze(0), roi).squeeze(0) + force_foreground = self.oversample_foreground(idx) + lbs, ubs = self.random_roi(idx, force_foreground) + dataset = self._annotations.dataset() + idx = self.convert_idx(idx) + return crop_and_pad(dataset.image(idx), lbs, ubs), crop_and_pad(dataset.label(idx), lbs, ubs) diff --git a/mipcandy/data/io.py b/mipcandy/data/io.py index b25b3c8..d76b253 100644 --- a/mipcandy/data/io.py +++ b/mipcandy/data/io.py @@ -1,3 +1,4 @@ +from gc import collect from math import floor from os import PathLike @@ -11,7 +12,7 @@ def fast_save(x: torch.Tensor, path: str | PathLike[str]) -> None: - save_file({"payload": x}, path) + save_file({"payload": x if x.is_contiguous() else x.contiguous()}, path) def fast_load(path: str | PathLike[str], *, device: Device = "cpu") -> torch.Tensor: @@ -39,7 +40,7 @@ def load_image(path: str | PathLike[str], *, is_label: bool = False, align_spaci file = SpITK.ReadImage(path) if align_spacing: file = resample_to_isotropic(file, interpolator=SpITK.sitkNearestNeighbor if is_label else SpITK.sitkBSpline) - img = torch.tensor(SpITK.GetArrayFromImage(file), dtype=torch.float, device=device) + img = torch.tensor(SpITK.GetArrayFromImage(file), dtype=torch.long if is_label else torch.float, device=device) if path.endswith(".nii.gz") or path.endswith(".nii") or path.endswith(".mha"): img = ensure_num_dimensions(img, 4, append_before=False).permute(3, 0, 1, 2) return img.squeeze(1) if img.shape[1] == 1 else img @@ -56,3 +57,13 @@ def save_image(image: torch.Tensor, path: str | PathLike[str]) -> None: image = auto_convert(ensure_num_dimensions(image, 3)).to(torch.uint8).permute(1, 2, 0) return SpITK.WriteImage(SpITK.GetImageFromArray(image.detach().cpu().numpy(), isVector=True), path) raise NotImplementedError(f"Unsupported file type: {path}") + + +def empty_cache(device: Device) -> None: 
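+ # Best-effort cache release for whatever backend `device` maps to: run the Python garbage collector on CPU, or ask the CUDA / MPS caching allocator to return unused blocks (useful e.g. between full-resolution validation cases).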
+ match torch.device(device).type: + case "cpu": + collect() + case "cuda": + torch.cuda.empty_cache() + case "mps": + torch.mps.empty_cache() diff --git a/mipcandy/data/sliding_window.py b/mipcandy/data/sliding_window.py index e5c0903..4de8dfb 100644 --- a/mipcandy/data/sliding_window.py +++ b/mipcandy/data/sliding_window.py @@ -1,12 +1,16 @@ +from ast import literal_eval +from dataclasses import dataclass +from functools import reduce from math import log10 +from operator import mul from os import PathLike, makedirs, listdir from typing import override, Literal import torch from rich.console import Console from rich.progress import Progress +from torch import nn -from mipcandy.common import Pad2d, Pad3d from mipcandy.data.dataset import UnsupervisedDataset, SupervisedDataset, MergedDataset, PathBasedUnsupervisedDataset, \ TensorLoader from mipcandy.data.io import fast_save @@ -14,26 +18,41 @@ from mipcandy.types import Shape, Transform, Device -def do_sliding_window(x: torch.Tensor, window_shape: Shape, *, overlap: float = .5) -> list[torch.Tensor]: - stride = tuple(int(s * (1 + overlap)) for s in window_shape) +def do_sliding_window(x: torch.Tensor, window_shape: Shape, *, overlap: float = .5) -> tuple[ + torch.Tensor, Shape, Shape]: + stride = tuple(int(s * (1 - overlap)) for s in window_shape) ndim = len(stride) if ndim not in (2, 3): raise ValueError(f"Window shape must be 2D or 3D, got {ndim}D") + original_shape = tuple(x.shape[1:]) + padded_shape = [] + for i, size in enumerate(original_shape): + if size <= window_shape[i]: + padded_shape.append(window_shape[i]) + else: + excess = (size - window_shape[i]) % stride[i] + padded_shape.append(size if excess == 0 else (size + stride[i] - excess)) + padding_values = [] + for i in range(ndim - 1, -1, -1): + pad_total = padded_shape[i] - original_shape[i] + pad_before = pad_total // 2 + pad_after = pad_total - pad_before + padding_values.extend([pad_before, pad_after]) + x = nn.functional.pad(x, padding_values, mode="constant", value=0) if ndim == 2: - x = Pad2d(stride, batch=False)(x) x = x.unfold(1, window_shape[0], stride[0]).unfold(2, window_shape[1], stride[1]) c, n_h, n_w, win_h, win_w = x.shape x = x.permute(1, 2, 0, 3, 4).reshape(n_h * n_w, c, win_h, win_w) - return [x[i] for i in range(x.shape[0])] - x = Pad3d(stride, batch=False)(x) - x = x.unfold(1, window_shape[0], stride[0]).unfold(2, window_shape[1], stride[1]).unfold(3, window_shape[2], - stride[2]) + return x, (n_h, n_w), (original_shape[0], original_shape[1]) + x = x.unfold(1, window_shape[0], stride[0]).unfold(2, window_shape[1], stride[1]).unfold( + 3, window_shape[2], stride[2]) c, n_d, n_h, n_w, win_d, win_h, win_w = x.shape x = x.permute(1, 2, 3, 0, 4, 5, 6).reshape(n_d * n_h * n_w, c, win_d, win_h, win_w) - return [x[i] for i in range(x.shape[0])] + return x, (n_d, n_h, n_w), (original_shape[0], original_shape[1], original_shape[2]) -def revert_sliding_window(windows: list[torch.Tensor], *, overlap: float = .5) -> torch.Tensor: +def revert_sliding_window(windows: torch.Tensor, layout: Shape, original_shape: Shape, *, + overlap: float = .5) -> torch.Tensor: first_window = windows[0] ndim = first_window.ndim - 1 if ndim not in (2, 3): @@ -41,96 +60,84 @@ def revert_sliding_window(windows: list[torch.Tensor], *, overlap: float = .5) - window_shape = first_window.shape[1:] c = first_window.shape[0] stride = tuple(int(w * (1 - overlap)) for w in window_shape) - num_windows = len(windows) if ndim == 2: h_win, w_win = window_shape - import math - grid_size = 
math.isqrt(num_windows) - n_h = n_w = grid_size - while n_h * n_w < num_windows: - n_w += 1 - if n_h * n_w > num_windows: - for nh in range(1, num_windows + 1): - if num_windows % nh == 0: - n_h = nh - n_w = num_windows // nh - break + n_h, n_w = layout out_h = (n_h - 1) * stride[0] + h_win out_w = (n_w - 1) * stride[1] + w_win - output = torch.zeros(1, c, out_h, out_w, device=first_window.device, dtype=first_window.dtype) - weights = torch.zeros(1, 1, out_h, out_w, device=first_window.device, dtype=first_window.dtype) - idx = 0 - for i in range(n_h): - for j in range(n_w): - if idx >= num_windows: - break - h_start = i * stride[0] - w_start = j * stride[1] - output[0, :, h_start:h_start + h_win, w_start:w_start + w_win] += windows[idx] - weights[0, 0, h_start:h_start + h_win, w_start:w_start + w_win] += 1 - idx += 1 - return output / weights.clamp(min=1) - else: - d_win, h_win, w_win = window_shape - import math - grid_size = round(num_windows ** (1 / 3)) - n_d = n_h = n_w = grid_size - while n_d * n_h * n_w < num_windows: - n_w += 1 - if n_d * n_h * n_w < num_windows: - n_h += 1 - if n_d * n_h * n_w < num_windows: - n_d += 1 - for nd in range(1, num_windows + 1): - if num_windows % nd == 0: - remaining = num_windows // nd - for nh in range(1, remaining + 1): - if remaining % nh == 0: - n_d = nd - n_h = nh - n_w = remaining // nh - break - break - out_d = (n_d - 1) * stride[0] + d_win - out_h = (n_h - 1) * stride[1] + h_win - out_w = (n_w - 1) * stride[2] + w_win - output = torch.zeros(1, c, out_d, out_h, out_w, device=first_window.device, dtype=first_window.dtype) - weights = torch.zeros(1, 1, out_d, out_h, out_w, device=first_window.device, dtype=first_window.dtype) - idx = 0 - for i in range(n_d): - for j in range(n_h): - for k in range(n_w): - if idx >= num_windows: - break - d_start = i * stride[0] - h_start = j * stride[1] - w_start = k * stride[2] - output[0, :, d_start:d_start + d_win, h_start:h_start + h_win, w_start:w_start + w_win] += windows[ - idx] - weights[0, 0, d_start:d_start + d_win, h_start:h_start + h_win, w_start:w_start + w_win] += 1 - idx += 1 - return output / weights.clamp(min=1) + windows_flat = windows[:n_h * n_w].view(n_h * n_w, c * h_win * w_win) + output = nn.functional.fold( + windows_flat.transpose(0, 1), + output_size=(out_h, out_w), + kernel_size=(h_win, w_win), + stride=stride + ) + weights = nn.functional.fold( + torch.ones(c * h_win * w_win, n_h * n_w, device=first_window.device, dtype=torch.uint8), + output_size=(out_h, out_w), + kernel_size=(h_win, w_win), + stride=stride + ).sum(dim=0, keepdim=True) + output /= weights.clamp(min=1) + pad_h = out_h - original_shape[0] + pad_w = out_w - original_shape[1] + h_start = pad_h // 2 + w_start = pad_w // 2 + return output[:, h_start:h_start + original_shape[0], w_start:w_start + original_shape[1]] + d_win, h_win, w_win = window_shape + n_d, n_h, n_w = layout + out_d = (n_d - 1) * stride[0] + d_win + out_h = (n_h - 1) * stride[1] + h_win + out_w = (n_w - 1) * stride[2] + w_win + output = torch.zeros(c, out_d, out_h, out_w, device=first_window.device, dtype=first_window.dtype) + weights = torch.zeros(1, out_d, out_h, out_w, device=first_window.device, dtype=torch.uint8) + windows = windows[:n_d * n_h * n_w].view(n_d, n_h, n_w, c, d_win, h_win, w_win) + for i in range(n_d): + d_start = i * stride[0] + d_slice = slice(d_start, d_start + d_win) + for j in range(n_h): + h_start = j * stride[1] + h_slice = slice(h_start, h_start + h_win) + for k in range(n_w): + w_start = k * stride[2] + w_slice = slice(w_start, 
w_start + w_win) + output[:, d_slice, h_slice, w_slice] += windows[i, j, k] + weights[0, d_slice, h_slice, w_slice] += 1 + output /= weights.clamp(min=1) + pad_d = out_d - original_shape[0] + pad_h = out_h - original_shape[1] + pad_w = out_w - original_shape[2] + d_start = pad_d // 2 + h_start = pad_h // 2 + w_start = pad_w // 2 + return output[:, d_start:d_start + original_shape[0], h_start:h_start + original_shape[1], + w_start:w_start + original_shape[2]] + + +def _slide_internal(image: torch.Tensor, window_shape: Shape, overlap: float, i: int, ind: int, output_folder: str, *, + is_label: bool = False) -> None: + windows, layout, original_shape = do_sliding_window(image, window_shape, overlap=overlap) + jnd = int(log10(windows.shape[0])) + 1 + for j in range(windows.shape[0]): + path = f"{output_folder}/{"labels" if is_label else "images"}/{str(i).zfill(ind)}_{str(j).zfill(jnd)}" + fast_save(windows[j], f"{path}_{layout}_{original_shape}.pt" if j == 0 else f"{path}.pt") def _slide(supervised: bool, dataset: UnsupervisedDataset | SupervisedDataset, output_folder: str | PathLike[str], window_shape: Shape, *, overlap: float = .5, console: Console = Console()) -> None: makedirs(f"{output_folder}/images", exist_ok=True) - makedirs(f"{output_folder}/labels", exist_ok=True) + if supervised: + makedirs(f"{output_folder}/labels", exist_ok=True) ind = int(log10(len(dataset))) + 1 with Progress(console=console) as progress: task = progress.add_task("Sliding dataset...", total=len(dataset)) for i, case in enumerate(dataset): image = case[0] if supervised else case progress.update(task, description=f"Sliding dataset {tuple(image.shape)}...") - windows = do_sliding_window(image, window_shape, overlap=overlap) - jnd = int(log10(len(windows))) + 1 - for j, window in enumerate(windows): - fast_save(window, f"{output_folder}/images/{str(i).zfill(ind)}_{str(j).zfill(jnd)}.pt") + _slide_internal(image, window_shape, overlap, i, ind, output_folder) if supervised: label = case[1] - windows = do_sliding_window(label, window_shape, overlap=overlap) - for j, window in enumerate(windows): - fast_save(window, f"{output_folder}/labels/{str(i).zfill(ind)}_{str(j).zfill(jnd)}.pt") + _slide_internal(label, window_shape, overlap, i, ind, output_folder, is_label=True) progress.update(task, advance=1, description=f"Sliding dataset ({i + 1}/{len(dataset)})...") @@ -140,18 +147,57 @@ def slide_dataset(dataset: UnsupervisedDataset | SupervisedDataset, output_folde console=console) +@dataclass +class SWCase(object): + window_indices: list[int] + layout: Shape | None + original_shape: Shape | None + + class UnsupervisedSWDataset(TensorLoader, PathBasedUnsupervisedDataset): def __init__(self, folder: str | PathLike[str], *, subfolder: Literal["images", "labels"] = "images", transform: Transform | None = None, device: Device = "cpu") -> None: super().__init__(sorted(listdir(f"{folder}/{subfolder}")), transform=transform, device=device) self._folder: str = folder self._subfolder: Literal["images", "labels"] = subfolder + self._groups: list[SWCase] = [] + for idx, filename in enumerate(self._images): + meta = filename[:filename.rfind(".")].split("_") + case_id = int(meta[0]) + if case_id >= len(self._groups): + if case_id != len(self._groups): + raise ValueError(f"Mismatched case id {case_id}") + self._groups.append(SWCase([], None, None)) + self._groups[case_id].window_indices.append(idx) + if len(meta) == 4: + if self._groups[case_id].layout: + raise ValueError(f"Duplicated layout specification for case {case_id}") + 
self._groups[case_id].layout = literal_eval(meta[2]) + if self._groups[case_id].original_shape: + raise ValueError(f"Duplicated original shape specification for case {case_id}") + self._groups[case_id].original_shape = literal_eval(meta[3]) + for idx, case in enumerate(self._groups): + windows, layout, original_shape = case.window_indices, case.layout, case.original_shape + if not layout: + raise ValueError(f"Layout not specified for case {idx}") + if not original_shape: + raise ValueError(f"Original shape not specified for case {idx}") + if len(windows) != reduce(mul, layout): + raise ValueError(f"Mismatched number of windows {len(windows)} and layout {layout} for case {idx}") @override def load(self, idx: int) -> torch.Tensor: return self.do_load(f"{self._folder}/{self._subfolder}/{self._images[idx]}", is_label=self._subfolder == "labels", device=self._device) + def case_meta(self, case_idx: int) -> tuple[int, Shape, Shape]: + case = self._groups[case_idx] + return len(case.window_indices), case.layout, case.original_shape + + def case(self, case_idx: int, *, part: slice | None = None) -> torch.Tensor: + indices = self._groups[case_idx].window_indices + return torch.stack([self[idx] for idx in (indices[part] if part else indices)]) + class SupervisedSWDataset(TensorLoader, MergedDataset, SupervisedDataset[UnsupervisedSWDataset]): def __init__(self, folder: str | PathLike[str], *, transform: JointTransform | None = None, diff --git a/mipcandy/data/transform.py b/mipcandy/data/transform.py index 5f87d7a..7fa2f27 100644 --- a/mipcandy/data/transform.py +++ b/mipcandy/data/transform.py @@ -8,32 +8,32 @@ class JointTransform(nn.Module): def __init__(self, *, transform: Transform | None = None, image_only: Transform | None = None, label_only: Transform | None = None, keys: tuple[str, str] = ("image", "label")) -> None: super().__init__() - self._transform: Transform | None = transform - self._image_only: Transform | None = image_only - self._label_only: Transform | None = label_only + self.transform: Transform | None = transform + self.image_only: Transform | None = image_only + self.label_only: Transform | None = label_only self._keys: tuple[str, str] = keys def forward(self, image: torch.Tensor, label: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: ik, lk = self._keys data = {ik: image, lk: label} - if self._transform: - data = self._transform(data) - if self._image_only: - data[ik] = self._image_only(data[ik]) - if self._label_only: - data[lk] = self._label_only(data[lk]) + if self.transform: + data = self.transform(data) + if self.image_only: + data[ik] = self.image_only(data[ik]) + if self.label_only: + data[lk] = self.label_only(data[lk]) return data[ik], data[lk] class MONAITransform(nn.Module): def __init__(self, transform: Transform, *, keys: tuple[str, str] = ("image", "label")) -> None: super().__init__() - self._transform: Transform = transform + self.transform: Transform = transform self._keys: tuple[str, str] = keys def forward(self, data: torch.Tensor | dict[str, torch.Tensor]) -> torch.Tensor | dict[str, torch.Tensor]: if isinstance(data, torch.Tensor): - return self._transform(data) + return self.transform(data) ik, lk = self._keys image, label = data[ik], data[lk] - return {ik: self._transform(image), lk: self._transform(label)} + return {ik: self.transform(image), lk: self.transform(label)} diff --git a/mipcandy/data/visualization.py b/mipcandy/data/visualization.py index 3c6c2c7..3052542 100644 --- a/mipcandy/data/visualization.py +++ b/mipcandy/data/visualization.py 
@@ -15,7 +15,7 @@ from mipcandy.data.geometric import ensure_num_dimensions -def visualize2d(image: torch.Tensor, *, title: str | None = None, cmap: str = "gray", +def visualize2d(image: torch.Tensor, *, title: str | None = None, cmap: str | None = None, is_label: bool = False, blocking: bool = False, screenshot_as: str | PathLike[str] | None = None) -> None: image = image.detach().cpu() if image.ndim < 2: @@ -28,6 +28,8 @@ def visualize2d(image: torch.Tensor, *, title: str | None = None, cmap: str = "g else: image = image.permute(1, 2, 0) image = auto_convert(image) + if not cmap: + cmap = "jet" if is_label else "gray" plt.imshow(image.numpy(), cmap, vmin=0, vmax=255) plt.title(title) plt.axis("off") @@ -50,10 +52,17 @@ def _visualize3d_with_pyvista(image: np.ndarray, title: str | None, cmap: str, p.show() -def visualize3d(image: torch.Tensor, *, title: str | None = None, cmap: str = "gray", max_volume: int = 1e6, +__LABEL_COLORMAP: list[str] = [ + "#ffffff", "#2e4057", "#7a0f1c", "#004f4f", "#9a7b00", "#2c2f38", "#5c136f", "#113f2e", "#8a3b12", "#2b1a6f", + "#4a5a1a", "#006b6e", "#3b1f14", "#0a2c66", "#5a0f3c", "#0f5c3a" +] + + +def visualize3d(image: torch.Tensor, *, title: str | None = None, cmap: str | list[str] | None = None, + max_volume: int = 1e6, is_label: bool = False, backend: Literal["auto", "matplotlib", "pyvista"] = "auto", blocking: bool = False, screenshot_as: str | PathLike[str] | None = None) -> None: - image = image.detach().float().cpu() + image = image.detach().cpu() if image.ndim < 3: raise ValueError(f"`image` must have at least 3 dimensions, got {image.shape}") if image.ndim > 3: @@ -62,11 +71,20 @@ def visualize3d(image: torch.Tensor, *, title: str | None = None, cmap: str = "g total = d * h * w ratio = int(ceil((total / max_volume) ** (1 / 3))) if total > max_volume else 1 if ratio > 1: - image = ensure_num_dimensions(nn.functional.avg_pool3d(ensure_num_dimensions(image, 5), kernel_size=ratio, - stride=ratio, ceil_mode=True), 3) - image = image.numpy() + image = ensure_num_dimensions(nn.functional.avg_pool3d( + ensure_num_dimensions(image, 5).float(), kernel_size=ratio, stride=ratio, ceil_mode=True + ), 3).to(image.dtype) if backend == "auto": backend = "pyvista" if find_spec("pyvista") else "matplotlib" + if is_label: + max_id = image.max() + if max_id > 1 and torch.is_floating_point(image): + raise ValueError(f"Label must be class ids that are in [0, 1] or of integer type, got {image.dtype}") + if not cmap: + cmap = __LABEL_COLORMAP[:int(max_id) + 1] if backend == "pyvista" and max_id < len(__LABEL_COLORMAP) else "jet" + elif not cmap: + cmap = "gray" + image = image.numpy() match backend: case "matplotlib": warn("Using Matplotlib for 3D visualization is inefficient and inaccurate, consider using PyVista") diff --git a/mipcandy/inference.py b/mipcandy/inference.py index 967a7ee..7da1bcb 100644 --- a/mipcandy/inference.py +++ b/mipcandy/inference.py @@ -50,7 +50,7 @@ def __init__(self, experiment_folder: str | PathLike[str], example_shape: AmbiguousShape def lazy_load_model(self) -> None: if self._model: return - self._model = self.load_model(self._example_shape, False, checkpoint=torch.load( + self._model = self.load_model(self._example_shape, False, path=(  # pass the checkpoint path itself; load_checkpoint() reads the file f"{self._experiment_folder}/{self._checkpoint}" )) self._model.eval() diff --git a/mipcandy/layer.py b/mipcandy/layer.py index 3ac5d5b..63b5e09 100644 --- a/mipcandy/layer.py +++ b/mipcandy/layer.py @@ -1,7 +1,9 @@ from abc import ABCMeta, abstractmethod -from typing import Any, Generator, Self, Mapping +from 
os import PathLike +from typing import Any, Generator, Self, override import torch +from safetensors.torch import save_model, load_model from torch import nn from mipcandy.types import Device, AmbiguousShape @@ -50,15 +52,14 @@ def __init__(self, device: Device) -> None: def device(self, *, device: Device | None = None) -> None | Device: if device is None: return self._device - else: - self._device = device + self._device = device def auto_device() -> Device: if torch.cuda.is_available(): return f"cuda:{max(range(torch.cuda.device_count()), key=lambda i: torch.cuda.memory_reserved(i) - torch.cuda.memory_allocated(i))}" - if torch.backends.mps.is_available(): + if torch.mps.is_available(): return "mps" return "cpu" @@ -96,24 +97,51 @@ def get_restoring_module(self) -> nn.Module | None: return self._restoring_module -class WithNetwork(HasDevice, metaclass=ABCMeta): +class WithCheckpoint(object, metaclass=ABCMeta): + @abstractmethod + def load_checkpoint(self, model: nn.Module, path: str | PathLike[str]) -> nn.Module: + raise NotImplementedError + + @abstractmethod + def save_checkpoint(self, model: nn.Module, path: str | PathLike[str]) -> None: + raise NotImplementedError + + +class WithNetwork(WithCheckpoint, HasDevice, metaclass=ABCMeta): def __init__(self, device: Device) -> None: super().__init__(device) + @override + def load_checkpoint(self, model: nn.Module, path: str | PathLike[str]) -> nn.Module: + load_model(model, path) + return model + + @override + def save_checkpoint(self, model: nn.Module, path: str | PathLike[str]) -> None: + save_model(model, path) + @abstractmethod def build_network(self, example_shape: AmbiguousShape) -> nn.Module: raise NotImplementedError - def build_network_from_checkpoint(self, example_shape: AmbiguousShape, checkpoint: Mapping[str, Any]) -> nn.Module: + @staticmethod + def compile_model(model: nn.Module) -> nn.Module: + return torch.compile(model) + + def build_network_from_checkpoint(self, example_shape: AmbiguousShape, path: str | PathLike[str], + compile_model: bool) -> nn.Module: """ Internally exposed interface for overriding. Use `load_model()` instead. 
""" - network = self.build_network(example_shape) - network.load_state_dict(checkpoint) - return network + model = self.build_network(example_shape) + return self.load_checkpoint(self.compile_model(model) if compile_model else model, path) def load_model(self, example_shape: AmbiguousShape, compile_model: bool, *, - checkpoint: Mapping[str, Any] | None = None) -> nn.Module: - model = (self.build_network_from_checkpoint(example_shape, checkpoint) if checkpoint else self.build_network( - example_shape)).to(self._device) - return torch.compile(model) if compile_model else model + path: str | PathLike[str] | None = None) -> nn.Module: + if path: + return self.build_network_from_checkpoint(example_shape, path, compile_model).to(self._device) + model = self.build_network(example_shape).to(self._device) + return self.compile_model(model) if compile_model else model + + def save_model(self, model: nn.Module, path: str | PathLike[str]) -> None: + self.save_checkpoint(model, path) diff --git a/mipcandy/metrics.py b/mipcandy/metrics.py index 3b065ef..592af12 100644 --- a/mipcandy/metrics.py +++ b/mipcandy/metrics.py @@ -60,14 +60,19 @@ def dice_similarity_coefficient_multiclass(output: torch.Tensor, label: torch.Te return apply_multiclass_to_binary(dice_similarity_coefficient_binary, output, label, num_classes, if_empty) -def soft_dice_coefficient(output: torch.Tensor, label: torch.Tensor, *, - smooth: float = 1e-5, include_background: bool = True) -> torch.Tensor: +def soft_dice_coefficient(output: torch.Tensor, label: torch.Tensor, *, smooth: float = 1, + batch: bool = True) -> torch.Tensor: _args_check(output, label) axes = tuple(range(2, output.ndim)) - intersection = (output * label).sum(dim=axes) - dice = (2 * intersection + smooth) / (output.sum(dim=axes) + label.sum(dim=axes) + smooth) - if not include_background: - dice = dice[:, 1:] + with torch.no_grad(): + label_sum = label.sum(axes) + intersection = (output * label).sum(axes) + output_sum = output.sum(axes) + if batch: + intersection = intersection.sum(0) + output_sum = output_sum.sum(0) + label_sum = label_sum.sum(0) + dice = (2 * intersection + smooth) / (torch.clip(label_sum + output_sum + smooth, 1e-8)) return dice.mean() diff --git a/mipcandy/presets/segmentation.py b/mipcandy/presets/segmentation.py index eb2f3d6..294101f 100644 --- a/mipcandy/presets/segmentation.py +++ b/mipcandy/presets/segmentation.py @@ -1,38 +1,89 @@ from abc import ABCMeta -from collections import defaultdict -from typing import override +from typing import override, Callable, Sequence, Any +import numpy as np import torch from rich.progress import Progress, SpinnerColumn from torch import nn, optim -from mipcandy.common import AbsoluteLinearLR, DiceBCELossWithLogits -from mipcandy.data import visualize2d, visualize3d, overlay, auto_convert, convert_logits_to_ids, \ - revert_sliding_window, PathBasedSupervisedDataset, SupervisedSWDataset +from mipcandy.common import PolyLRScheduler, DiceBCELossWithLogits, DiceCELossWithLogits +from mipcandy.data import visualize2d, visualize3d, overlay, auto_convert, convert_logits_to_ids, SupervisedDataset, \ + revert_sliding_window, SupervisedSWDataset, fast_save from mipcandy.training import Trainer, TrainerToolbox, try_append_all -from mipcandy.types import Params +from mipcandy.types import Params, Shape + + +def print_stats_of_class_ids(x: torch.Tensor, name: str, num_classes: int) -> None: + print(f"{name} unique", x.unique()) + binc_p = torch.bincount(x.flatten(), minlength=num_classes) + print(f"{name} class 
distribution:", (binc_p / binc_p.sum()).cpu().tolist()) + + +class DeepSupervisionWrapper(nn.Module): + def __init__(self, loss: nn.Module, *, weight_factors: Sequence[float] | None = None) -> None: + super().__init__() + if weight_factors and all(x == 0 for x in weight_factors): + raise ValueError("At least one weight factor should be nonzero") + self.weight_factors: tuple[float, ...] = tuple(weight_factors) + self.loss: nn.Module = loss + + @override + def __getattr__(self, item: str) -> Any: + return self.loss.validation_mode if item == "validation_mode" else super().__getattr__(item) + + @override + def __setattr__(self, name: str, value: Any) -> None: + if name == "validation_mode" and hasattr(self.loss, "validation_mode"): + self.loss.validation_mode = value + super().__setattr__(name, value) + + def forward(self, outputs: Sequence[torch.Tensor], targets: Sequence[torch.Tensor]) -> tuple[ + torch.Tensor, dict[str, float]]: + if not self.weight_factors: + weights = (1.0,) * len(outputs) + else: + weights = self.weight_factors + total_loss = torch.tensor(0, device=outputs[0].device, dtype=outputs[0].dtype) + combined_metrics = {} + for i, (output, target) in enumerate(zip(outputs, targets)): + if weights[i] == 0: + continue + loss, metrics = self.loss(output, target) + total_loss += weights[i] * loss + for key, value in metrics.items(): + metric_key = f"{key}_ds{i}" if len(outputs) > 1 else key + combined_metrics[metric_key] = value + if combined_metrics: + main_loss_key = next(iter(combined_metrics.keys())).replace("_ds0", "") + if f"{main_loss_key}_ds0" in combined_metrics: + combined_metrics[main_loss_key] = combined_metrics[f"{main_loss_key}_ds0"] + return total_loss, combined_metrics class SegmentationTrainer(Trainer, metaclass=ABCMeta): num_classes: int = 1 include_background: bool = True + deep_supervision: bool = False + deep_supervision_scales: Sequence[float] | None = None + deep_supervision_weights: Sequence[float] | None = None - def _save_preview(self, x: torch.Tensor, title: str, quality: float) -> None: + def _save_preview(self, x: torch.Tensor, title: str, quality: float, *, is_label: bool = False) -> None: path = f"{self.experiment_folder()}/{title} (preview).png" if x.ndim == 3 and x.shape[0] in (1, 3, 4): - visualize2d(auto_convert(x), title=title, blocking=True, screenshot_as=path) + visualize2d(auto_convert(x), title=title, is_label=is_label, blocking=True, screenshot_as=path) elif x.ndim == 4 and x.shape[0] == 1: - visualize3d(x, title=title, max_volume=int(quality * 1e6), blocking=True, screenshot_as=path) + visualize3d(x, title=title, max_volume=int(quality * 1e6), is_label=is_label, blocking=True, + screenshot_as=path) @override def save_preview(self, image: torch.Tensor, label: torch.Tensor, output: torch.Tensor, *, quality: float = .75) -> None: - output = output.sigmoid() + output = output.sigmoid() if self.num_classes < 2 else output.softmax(0) if output.shape[0] != 1: - output = convert_logits_to_ids(output.unsqueeze(0)).squeeze(0) + output = convert_logits_to_ids(output, channel_dim=0).int() self._save_preview(image, "input", quality) - self._save_preview(label, "label", quality) - self._save_preview(output, "prediction", quality) + self._save_preview(label.int(), "label", quality, is_label=True) + self._save_preview(output, "prediction", quality, is_label=True) if image.ndim == label.ndim == output.ndim == 3 and label.shape[0] == output.shape[0] == 1: visualize2d(overlay(image, label), title="expected", blocking=True, 
screenshot_as=f"{self.experiment_folder()}/expected (preview).png") @@ -45,96 +96,193 @@ def build_ema(self, model: nn.Module) -> nn.Module: @override def build_criterion(self) -> nn.Module: - return DiceBCELossWithLogits(self.num_classes, include_background=self.include_background) + if self.num_classes < 2: + loss = DiceBCELossWithLogits(include_background=self.include_background) + else: + loss = DiceCELossWithLogits(self.num_classes, include_background=self.include_background) + if self.deep_supervision: + if not self.deep_supervision_weights and self.deep_supervision_scales: + weights = np.array([1 / (2 ** i) for i in range(len(self.deep_supervision_scales))]) + weights = weights / weights.sum() + self.deep_supervision_weights = tuple(weights.tolist()) + loss = DeepSupervisionWrapper(loss, weight_factors=self.deep_supervision_weights) + self.log(f"Deep supervision enabled with weights: {self.deep_supervision_weights}") + return loss @override def build_optimizer(self, params: Params) -> optim.Optimizer: - return optim.AdamW(params) + return optim.SGD(params, 1e-2, weight_decay=3e-5, momentum=.99, nesterov=True) @override def build_scheduler(self, optimizer: optim.Optimizer, num_epochs: int) -> optim.lr_scheduler.LRScheduler: - return AbsoluteLinearLR(optimizer, -8e-6 / len(self._dataloader), 1e-2) + return PolyLRScheduler(optimizer, 1e-2, num_epochs * len(self._dataloader)) @override def backward(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[ str, float]]: - masks = toolbox.model(images) - loss, metrics = toolbox.criterion(masks, labels) + outputs = toolbox.model(images) + if self.deep_supervision and outputs.ndim == labels.ndim + 1: + masks_list = list(torch.unbind(outputs, dim=1)) + targets = self.prepare_deep_supervision_targets(labels, [m.shape[2:] for m in masks_list]) + loss, metrics = toolbox.criterion(masks_list, targets) + elif self.deep_supervision and isinstance(outputs, (list, tuple)): + targets = self.prepare_deep_supervision_targets(labels, [m.shape[2:] for m in outputs]) + loss, metrics = toolbox.criterion(outputs, targets) + else: + with torch.no_grad(): + print_stats_of_class_ids(labels, "label", self.num_classes) + preds = outputs.softmax(1) + preds = convert_logits_to_ids(preds) + print_stats_of_class_ids(preds, "prediction", self.num_classes) + print("=====" * 10) + loss, metrics = toolbox.criterion(outputs, labels) loss.backward() + nn.utils.clip_grad_norm_(toolbox.model.parameters(), 12) return loss.item(), metrics + @staticmethod + def prepare_deep_supervision_targets(labels: torch.Tensor, output_shapes: list[tuple[int, ...]]) -> list[ + torch.Tensor]: + targets = [] + for shape in output_shapes: + if labels.shape[2:] == shape: + targets.append(labels) + else: + downsampled = nn.functional.interpolate(labels.float(), shape, + mode="nearest-exact" if labels.ndim == 4 else "nearest") + targets.append(downsampled if labels.dtype == torch.float32 else downsampled.to(labels.dtype)) + return targets + @override - def validate_case(self, image: torch.Tensor, label: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[ - str, float], torch.Tensor]: + def validate_case(self, idx: int, image: torch.Tensor, label: torch.Tensor, toolbox: TrainerToolbox) -> tuple[ + float, dict[str, float], torch.Tensor]: image, label = image.unsqueeze(0), label.unsqueeze(0) - mask = (toolbox.ema if toolbox.ema else toolbox.model)(image) - loss, metrics = toolbox.criterion(mask, label) - return -loss.item(), metrics, mask.squeeze(0) + 
output = (toolbox.ema if toolbox.ema else toolbox.model)(image) + # (B, N, C, H, W, D) with the highest resolution is at index 0 + if self.deep_supervision and output.ndim == label.ndim + 1: + mask_for_loss = output[:, 0] + mask_output = output[:, 0] + elif self.deep_supervision and isinstance(output, (list, tuple)): + mask_for_loss = output[0] + mask_output = output[0] + else: + mask_for_loss = output + mask_output = output + if hasattr(toolbox.criterion, "validation_mode"): + toolbox.criterion.validation_mode = True + if self.deep_supervision and isinstance(toolbox.criterion, DeepSupervisionWrapper): + loss, metrics = toolbox.criterion([mask_for_loss], [label]) + else: + loss, metrics = toolbox.criterion(mask_for_loss, label) + if hasattr(toolbox.criterion, "validation_mode"): + toolbox.criterion.validation_mode = False + self.log(f"Metrics for case {idx}: {metrics}") + return -loss.item(), metrics, mask_output.squeeze(0) class SlidingTrainer(SegmentationTrainer, metaclass=ABCMeta): overlap: float = .5 - _validation_dataset: PathBasedSupervisedDataset | None = None + window_batch_size: int = 1 + full_validation_at_epochs: list[Callable[[int], int]] = [lambda num_epochs: num_epochs - 1] + compute_loss_on_device: bool = False + _full_validation_dataset: SupervisedDataset | None = None _slided_validation_dataset: SupervisedSWDataset | None = None - def set_validation_datasets(self, dataset: PathBasedSupervisedDataset, slided_dataset: SupervisedSWDataset) -> None: - self._validation_dataset = dataset - self._slided_validation_dataset = slided_dataset + def set_datasets(self, full_dataset: SupervisedDataset, slided_dataset: SupervisedSWDataset) -> None: + self.set_full_validation_dataset(full_dataset) + self.set_slided_validation_dataset(slided_dataset) + + def set_full_validation_dataset(self, dataset: SupervisedDataset) -> None: + dataset.device(device=self._device if self.compute_loss_on_device else "cpu") + self._full_validation_dataset = dataset + + def full_validation_dataset(self) -> SupervisedDataset: + if self._full_validation_dataset: + return self._full_validation_dataset + raise ValueError("Full validation dataset is not set") - def validation_dataset(self) -> PathBasedSupervisedDataset: - if self._validation_dataset: - return self._validation_dataset - raise ValueError("Validation datasets are not set") + def set_slided_validation_dataset(self, dataset: SupervisedSWDataset) -> None: + self._slided_validation_dataset = dataset def slided_validation_dataset(self) -> SupervisedSWDataset: if self._slided_validation_dataset: return self._slided_validation_dataset - raise ValueError("Validation datasets are not set") + raise ValueError("Slided validation dataset is not set") @override def validate(self, toolbox: TrainerToolbox) -> tuple[float, dict[str, list[float]]]: - validation_dataset = self.validation_dataset() - slided_validation_dataset = self.slided_validation_dataset() - image_files = slided_validation_dataset.images().paths() - groups = defaultdict(list) - for idx, filename in enumerate(image_files): - case_id = filename.split("_")[0] - groups[case_id].append(idx) + if self._tracker.epoch not in self.full_validation_at_epochs: + return super().validate(toolbox) + self.log("Performing full-resolution validation") + return self.fully_validate(toolbox) + + def fully_validate(self, toolbox: TrainerToolbox) -> tuple[float, dict[str, list[float]]]: + self.record_profiler_linebreak(f"Fully validating epoch {self._tracker.epoch}") + self.record_profiler() + 
self.record_profiler_linebreak("Emptying cache") + self.empty_cache() + self.record_profiler() toolbox.model.eval() if toolbox.ema: toolbox.ema.eval() score = 0 worst_score = float("+inf") metrics = {} - num_cases = len(groups) + num_cases = len(self._full_validation_dataset) with torch.no_grad(), Progress( *Progress.get_default_columns(), SpinnerColumn(), console=self._console ) as progress: - val_prog = progress.add_task("Validating", total=num_cases) - for case_idx, case_id in enumerate(sorted(groups.keys())): - patches = [slided_validation_dataset[idx][0].to(self._device) for idx in groups[case_id]] - label = validation_dataset[case_idx][1].to(self._device) - progress.update(val_prog, description=f"Validating case {case_id} ({len(patches)} patches)") - case_score, case_metrics, output = self.validate_case(patches, label, toolbox) + task = progress.add_task(f"Fully validating", total=num_cases) + for idx in range(num_cases): + progress.update(task, description=f"Validating epoch {self._tracker.epoch} case {idx}") + case_score, case_metrics, output = self.fully_validate_case(idx, toolbox) + self.record_profiler() + self.record_profiler_linebreak("Emptying cache") + self.empty_cache() + self.record_profiler() score += case_score if case_score < worst_score: - self._tracker.worst_case = (validation_dataset[case_idx][0], label, output) + self._tracker.worst_case = idx + fast_save(output, f"{self.experiment_folder()}/worst_full_output.pt") worst_score = case_score try_append_all(case_metrics, metrics) - progress.update(val_prog, advance=1, description=f"Validating ({case_score:.4f})") + progress.update(task, advance=1, + description=f"Validating epoch {self._tracker.epoch} case {idx} ({case_score:.4f})") + self.record_profiler() return score / num_cases, metrics - @override - def validate_case(self, patches: list[torch.Tensor], label: torch.Tensor, toolbox: TrainerToolbox) -> tuple[ - float, dict[str, float], torch.Tensor]: + def infer_validation_case(self, idx: int, toolbox: TrainerToolbox) -> tuple[torch.Tensor, Shape, Shape]: model = toolbox.ema if toolbox.ema else toolbox.model - outputs = [] - for patch in patches: - outputs.append(model(patch.unsqueeze(0)).squeeze(0)) - reconstructed = revert_sliding_window(outputs, overlap=self.overlap) - pad = [] - for r, l in zip(reversed(reconstructed.shape[2:]), reversed(label.shape[1:])): - pad.extend([0, r - l]) - label = nn.functional.pad(label, pad) - loss, metrics = toolbox.criterion(reconstructed, label.unsqueeze(0)) - return -loss.item(), metrics, reconstructed.squeeze(0) + images = self.slided_validation_dataset().images() + num_windows, layout, original_shape = images.case_meta(idx) + canvas = None + for i in range(0, num_windows, self.window_batch_size): + end = min(i + self.window_batch_size, num_windows) + outputs = model(images.case(idx, part=slice(i, end)).to(self._device)) + + # For deep supervision, only use the highest resolution output + # DynUNet stacks as (B, N, C, H, W, D), extract first output at dim 1 + if self.deep_supervision and outputs.ndim == 6: # 6D tensor for 3D images with deep supervision + outputs = outputs[:, 0] # Extract (B, C, H, W, D) + elif self.deep_supervision and outputs.ndim == 5: # 5D tensor for 2D images with deep supervision + outputs = outputs[:, 0] # Extract (B, C, H, W) + elif self.deep_supervision and isinstance(outputs, (list, tuple)): + outputs = outputs[0] + + if canvas is None: + canvas = torch.empty((num_windows, *outputs.shape[1:]), dtype=outputs.dtype, device=self._device) + canvas[i:end] = 
outputs + return canvas, layout, original_shape + + def fully_validate_case(self, idx: int, toolbox: TrainerToolbox) -> tuple[ + float, dict[str, float], torch.Tensor]: + windows, layout, original_shape = self.infer_validation_case(idx, toolbox) + self.empty_cache() + reconstructed = revert_sliding_window(windows, layout, original_shape, overlap=self.overlap) + if self.compute_loss_on_device: + self.empty_cache() + else: + reconstructed = reconstructed.cpu() + label = self._full_validation_dataset.label(idx) + loss, metrics = toolbox.criterion(reconstructed.unsqueeze(0), label.unsqueeze(0)) + return -loss.item(), metrics, reconstructed diff --git a/mipcandy/profiler.py b/mipcandy/profiler.py new file mode 100644 index 0000000..d16a9de --- /dev/null +++ b/mipcandy/profiler.py @@ -0,0 +1,95 @@ +from dataclasses import dataclass +from inspect import stack +from os import PathLike +from time import time +from typing import Sequence, override + +import torch +from psutil import cpu_percent, virtual_memory + +from mipcandy.types import Device + + +@dataclass +class ProfilerFrame(object): + stack: str + cpu: float + mem: float + gpu: list[float] | None = None + gpu_mem: list[float] | None = None + + @override + def __str__(self) -> str: + r = f"[{self.stack}] CPU: {self.cpu:.2f}% @ Memory: {self.mem:.2f}%\n" + if self.gpu and self.gpu_mem: + for i, gpu in enumerate(self.gpu): + r += f"\t\tGPU {i}: {gpu:.2f}% @ Memory: {self.gpu_mem[i]:.2f}%\n" + return r + + def export(self, duration: float) -> str: + return f"{duration:.2f}s\t{self}" + + +class _LineBreak(object): + def __init__(self, message: str) -> None: + self.message: str = message + + @override + def __str__(self) -> str: + return f"<{self.message}>\n" + + def export(self, duration: float) -> str: + return f"{duration:.2f}s\t{self}" + + +class Profiler(object): + def __init__(self, title: str, save_as: str | PathLike[str], *, gpus: Sequence[Device] = ()) -> None: + self.title: str = title + self.save_as: str = save_as + self.total_mem: float = self.get_total_mem() + self.has_gpu: bool = len(gpus) > 0 + self._gpus: Sequence[Device] = gpus + self.total_gpu_mem: list[float] = [self.get_total_gpu_mem(device) for device in gpus] + with open(save_as, "w") as f: + f.write(f"# {title}\nTotal memory: {self.total_mem}, Total GPU memory: {self.total_gpu_mem}\n\n") + self._t0: float = time() + + @staticmethod + def get_cpu_usage() -> float: + return cpu_percent() + + def get_mem_usage(self) -> float: + return 100 * virtual_memory().used / self.total_mem + + @staticmethod + def get_total_mem() -> float: + return virtual_memory().total + + @staticmethod + def get_gpu_usage(device: Device) -> float: + return torch.cuda.utilization(device) + + def get_gpu_mem_usage(self, device: Device) -> float: + return 100 * torch.cuda.device_memory_used(device) / self.total_gpu_mem[self._gpus.index(device)] + + @staticmethod + def get_total_gpu_mem(device: Device) -> float: + return torch.cuda.get_device_properties(device).total_memory + + def _save(self, obj: ProfilerFrame | _LineBreak) -> None: + with open(self.save_as, "a") as f: + t = time() + f.write(f"{obj.export(t - self._t0)}\n") + self._t0 = t + + def record(self, *, stack_trace_offset: int = 1) -> ProfilerFrame: + frame = ProfilerFrame(" -> ".join([f"{f.function}:{f.lineno}" for f in reversed(stack()[stack_trace_offset:])]), + self.get_cpu_usage(), self.get_mem_usage()) + if self.has_gpu: + frame.gpu = [torch.cuda.utilization(device) for device in self._gpus] + frame.gpu_mem = [self.get_gpu_mem_usage(device) 
for device in self._gpus] + self._save(frame) + return frame + + def line_break(self, message: str) -> None: + self._save(_LineBreak(message)) diff --git a/mipcandy/sanity_check.py b/mipcandy/sanity_check.py index 274e8f4..94fbae1 100644 --- a/mipcandy/sanity_check.py +++ b/mipcandy/sanity_check.py @@ -35,9 +35,10 @@ def __str__(self) -> str: def sanity_check(model: nn.Module, input_shape: Sequence[int], *, device: Device | None = None) -> SanityCheckResult: if device is None: device = auto_device() - num_macs, num_params, layer_stats = model_complexity_info(model, input_shape) - if num_macs is None or num_params is None: - raise RuntimeError("Failed to validate model") - outputs = model.to(device).eval()(torch.randn(1, *input_shape, device=device)) + with torch.no_grad(): + num_macs, num_params, layer_stats = model_complexity_info(model, input_shape) + if num_macs is None or num_params is None: + raise RuntimeError("Failed to validate model") + outputs = model.to(device).eval()(torch.randn(1, *input_shape, device=device)) return SanityCheckResult(num_macs, num_params, layer_stats, ( outputs[0] if isinstance(outputs, tuple) else outputs).squeeze(0)) diff --git a/mipcandy/training.py b/mipcandy/training.py index b94975d..1affed6 100644 --- a/mipcandy/training.py +++ b/mipcandy/training.py @@ -1,5 +1,5 @@ from abc import ABCMeta, abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, asdict from datetime import datetime from hashlib import md5 from json import load, dump @@ -9,7 +9,7 @@ from shutil import copy from threading import Lock from time import time -from typing import Sequence, override, Callable, Self +from typing import Sequence, override, Self import numpy as np import torch @@ -23,8 +23,10 @@ from mipcandy.common import quotient_regression, quotient_derivative, quotient_bounds from mipcandy.config import load_settings, load_secrets +from mipcandy.data import fast_save, fast_load, empty_cache from mipcandy.frontend import Frontend from mipcandy.layer import WithPaddingModule, WithNetwork +from mipcandy.profiler import Profiler from mipcandy.sanity_check import sanity_check from mipcandy.types import Params, Setting, AmbiguousShape @@ -54,13 +56,13 @@ class TrainerToolbox(object): class TrainerTracker(object): epoch: int = 0 best_score: float = float("-inf") - worst_case: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None + worst_case: int | None = None class Trainer(WithPaddingModule, WithNetwork, metaclass=ABCMeta): def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]], validation_dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]], *, recoverable: bool = True, - device: torch.device | str = "cpu", console: Console = Console()) -> None: + profiler: bool = False, device: torch.device | str = "cpu", console: Console = Console()) -> None: WithPaddingModule.__init__(self, device) WithNetwork.__init__(self, device) self._trainer_folder: str = trainer_folder @@ -71,10 +73,11 @@ def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[t self._unrecoverable: bool | None = not recoverable # None if the trainer is recovered self._console: Console = console self._metrics: dict[str, list[float]] = {} - self._epoch_metrics: dict[str, list[float]] = {} self._frontend: Frontend = Frontend({}) self._lock: Lock = Lock() self._tracker: TrainerTracker = TrainerTracker() + self._profiler: Profiler | None = None + self._use_profiler: bool = profiler # Recovery methods 
(PR #108 at https://github.com/ProjectNeura/MIPCandy/pull/108) @@ -82,19 +85,23 @@ def save_everything_for_recovery(self, toolbox: TrainerToolbox, tracker: Trainer **training_arguments) -> None: if self._unrecoverable: return - torch.save(toolbox.optimizer.state_dict(), f"{self.experiment_folder()}/optimizer.pth") - torch.save(toolbox.scheduler.state_dict(), f"{self.experiment_folder()}/scheduler.pth") - torch.save(toolbox.criterion.state_dict(), f"{self.experiment_folder()}/criterion.pth") - torch.save(tracker, f"{self.experiment_folder()}/tracker.pt") - with open(f"{self.experiment_folder()}/training_arguments.json", "w") as f: - dump(training_arguments, f) + torch.save({ + "optimizer": toolbox.optimizer.state_dict(), + "scheduler": toolbox.scheduler.state_dict(), + "criterion": toolbox.criterion.state_dict() + }, f"{self.experiment_folder()}/state_dicts.pth") + with open(f"{self.experiment_folder()}/state_orb.json", "w") as f: + dump({"tracker": asdict(tracker), "training_arguments": training_arguments}, f) + + def load_state_orb(self) -> dict[str, dict[str, Setting]]: + with open(f"{self.experiment_folder()}/state_orb.json") as f: + return load(f) def load_tracker(self) -> TrainerTracker: - return torch.load(f"{self.experiment_folder()}/tracker.pt", weights_only=False) + return TrainerTracker(**self.load_state_orb()["tracker"]) def load_training_arguments(self) -> dict[str, Setting]: - with open(f"{self.experiment_folder()}/training_arguments.json") as f: - return load(f) + return self.load_state_orb()["training_arguments"] def load_metrics(self) -> dict[str, list[float]]: df = read_csv(f"{self.experiment_folder()}/metrics.csv", index_col="epoch") @@ -103,11 +110,12 @@ def load_metrics(self) -> dict[str, list[float]]: def load_toolbox(self, num_epochs: int, example_shape: AmbiguousShape, compile_model: bool, ema: bool) -> TrainerToolbox: toolbox = self._build_toolbox(num_epochs, example_shape, compile_model, ema, model=self.load_model( - example_shape, compile_model, checkpoint=torch.load(f"{self.experiment_folder()}/checkpoint_latest.pth") + example_shape, compile_model, path=f"{self.experiment_folder()}/checkpoint_latest.pth" )) - toolbox.optimizer.load_state_dict(torch.load(f"{self.experiment_folder()}/optimizer.pth")) - toolbox.scheduler.load_state_dict(torch.load(f"{self.experiment_folder()}/scheduler.pth")) - toolbox.criterion.load_state_dict(torch.load(f"{self.experiment_folder()}/criterion.pth")) + state_dicts = torch.load(f"{self.experiment_folder()}/state_dicts.pth") + toolbox.optimizer.load_state_dict(state_dicts["optimizer"]) + toolbox.scheduler.load_state_dict(state_dicts["scheduler"]) + toolbox.criterion.load_state_dict(state_dicts["criterion"]) return toolbox def recover_from(self, experiment_id: str) -> Self: @@ -224,7 +232,10 @@ def init_experiment(self) -> None: with open(f"{experiment_folder}/logs.txt", "w") as f: f.write(f"File created by FightTumor, copyright (C) {t.year} Project Neura. 
All rights reserved\n") self.log(f"Experiment (ID {self._experiment_id}) created at {t}") - self.log(f"Trainer: {self.__class__.__name__}") + self.log(f"Trainer: {self._trainer_variant}") + if self._use_profiler: + gpus = (self._device,) if torch.device(self._device).type == "cuda" else () + self._profiler = Profiler(self._trainer_variant, f"{experiment_folder}/profiler.txt", gpus=gpus) # Logging utilities @@ -238,19 +249,19 @@ def log(self, msg: str, *, on_screen: bool = True) -> None: self._console.print(msg) def record(self, metric: str, value: float) -> None: - try_append(value, self._epoch_metrics, metric) - - def _record(self, metric: str, value: float) -> None: try_append(value, self._metrics, metric) - def record_all(self, metrics: dict[str, float]) -> None: - try_append_all(metrics, self._epoch_metrics) + def record_all(self, metrics: dict[str, list[float]]) -> None: + try_append_all({k: sum(v) / len(v) for k, v in metrics.items()}, self._metrics) + + def record_profiler(self) -> None: + if self._profiler: + self._profiler.record(stack_trace_offset=2) - def _bump_metrics(self) -> None: - for metric, values in self._epoch_metrics.items(): - epoch_overall = sum(values) / len(values) - try_append(epoch_overall, self._metrics, metric) - self._epoch_metrics.clear() + def record_profiler_linebreak(self, message: str) -> None: + if self._profiler: + self._profiler.line_break(message) + self.log(f"[PROFILER] {message}") def save_metrics(self) -> None: df = DataFrame(self._metrics) @@ -293,10 +304,8 @@ def save_preview(self, image: torch.Tensor, label: torch.Tensor, output: torch.T quality: float = .75) -> None: ... - def show_metrics(self, epoch: int, *, metrics: dict[str, list[float]] | None = None, prefix: str = "training", - epochwise: bool = True, skip: Callable[[str, list[float]], bool] | None = None) -> None: - if not metrics: - metrics = self._metrics + def show_metrics(self, epoch: int, metrics: dict[str, list[float]], prefix: str, *, epochwise: bool = True, + lookup_prefix: str = "", global_previous_index: int = -2) -> None: prefix = prefix.capitalize() table = Table(title=f"Epoch {epoch} {prefix}") table.add_column("Metric") @@ -304,16 +313,18 @@ def show_metrics(self, epoch: int, *, metrics: dict[str, list[float]] | None = N table.add_column("Span", style="cyan") table.add_column("Diff", style="magenta") for metric, values in metrics.items(): - if skip and skip(metric, values): - continue span = f"[{min(values):.4f}, {max(values):.4f}]" if epochwise: - value = f"{values[-1]:.4f}" - diff = f"{values[-1] - values[-2]:+.4f}" if len(values) > 1 else "N/A" - else: + if global_previous_index >= 0: + raise ValueError("`global_previous_index` must be negative`") mean = sum(values) / len(values) value = f"{mean:.4f}" - diff = f"{mean - self._metrics[metric][-1]:+.4f}" if metric in self._metrics else "N/A" + m = f"{lookup_prefix}{metric}" + diff = f"{mean - self._metrics[m][global_previous_index]:+.4f}" if m in self._metrics and len( + self._metrics[m]) >= -global_previous_index else "N/A" + else: + value = f"{values[-1]:.4f}" + diff = f"{values[-1] - values[-2]:+.4f}" if len(values) > 1 else "N/A" table.add_row(metric, value, span, diff) self.log(f"{prefix} {metric}: {value} @{span} ({diff})") console = Console() @@ -350,6 +361,11 @@ def build_toolbox(self, num_epochs: int, example_shape: AmbiguousShape, compile_ ema: bool) -> TrainerToolbox: return self._build_toolbox(num_epochs, example_shape, compile_model, ema) + # Performance + + def empty_cache(self) -> None: + 
empty_cache(self._device) + # Training methods @abstractmethod @@ -367,23 +383,30 @@ def train_batch(self, images: torch.Tensor, labels: torch.Tensor, toolbox: Train toolbox.ema.update_parameters(toolbox.model) return loss, metrics - def train_epoch(self, epoch: int, toolbox: TrainerToolbox) -> None: + def train_epoch(self, toolbox: TrainerToolbox) -> dict[str, list[float]]: + self.record_profiler_linebreak(f"Epoch {self._tracker.epoch} training") + self.record_profiler() + self.record_profiler_linebreak("Emptying cache") + self.empty_cache() + self.record_profiler() toolbox.model.train() if toolbox.ema: toolbox.ema.train() + metrics = {} with Progress(*Progress.get_default_columns(), SpinnerColumn(), console=self._console) as progress: - epoch_prog = progress.add_task(f"Epoch {epoch}", total=len(self._dataloader)) + task = progress.add_task(f"Epoch {self._tracker.epoch}", total=len(self._dataloader)) for images, labels in self._dataloader: - images, labels = images.to(self._device), labels.to(self._device) + images, labels = images.to(self._device, non_blocking=True), labels.to(self._device, non_blocking=True) padding_module = self.get_padding_module() if padding_module: images, labels = padding_module(images), padding_module(labels) - progress.update(epoch_prog, description=f"Training epoch {epoch} {tuple(images.shape)}") - loss, metrics = self.train_batch(images, labels, toolbox) - self.record("combined loss", loss) - self.record_all(metrics) - progress.update(epoch_prog, advance=1, description=f"Training epoch {epoch} ({loss:.4f})") - self._bump_metrics() + progress.update(task, description=f"Training epoch {self._tracker.epoch} {tuple(images.shape)}") + loss, batch_metrics = self.train_batch(images, labels, toolbox) + try_append(loss, metrics, "combined loss") + try_append_all(batch_metrics, metrics) + progress.update(task, advance=1, description=f"Training epoch {self._tracker.epoch} ({loss:.4f})") + self.record_profiler() + return metrics def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, compile_model: bool = True, ema: bool = True, seed: int | None = None, early_stop_tolerance: int = 5, @@ -396,6 +419,8 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, co if seed is None: seed = randint(0, 100) self.set_seed(seed) + self.record_profiler() + self.record_profiler_linebreak("Sanity check") example_input = self.get_example_input().to(self._device).unsqueeze(0) padding_module = self.get_padding_module() if padding_module: @@ -409,6 +434,7 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, co sanity_check_result = sanity_check(template_model, example_shape, device=self._device) self.log(str(sanity_check_result)) self.log(f"Example output shape: {tuple(sanity_check_result.output.shape)}") + self.record_profiler() self.log("Building toolbox...") toolbox = (self.load_toolbox if self.recovery() else self.build_toolbox)( num_epochs, example_shape, compile_model, ema @@ -418,6 +444,8 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, co self._frontend.on_experiment_created(self._experiment_id, self._trainer_variant, model_name, note, sanity_check_result.num_macs, sanity_check_result.num_params, num_epochs, early_stop_tolerance) + del sanity_check_result, template_model, example_input + self.empty_cache() try: for epoch in range(self._tracker.epoch, self._tracker.epoch + num_epochs): if early_stop_tolerance == -1: @@ -428,18 +456,20 @@ def train(self, num_epochs: int, *, note: 
str = "", num_checkpoints: int = 5, co self._tracker.epoch = epoch # Training t0 = time() - self.train_epoch(epoch, toolbox) + metrics = self.train_epoch(toolbox) + self.record_all(metrics) lr = toolbox.scheduler.get_last_lr()[0] - self._record("learning rate", lr) - self.show_metrics(epoch, skip=lambda m, _: m.startswith("val ") or m == "epoch duration") - torch.save(toolbox.model.state_dict(), checkpoint_path("latest")) + self.record("learning rate", lr) + self.show_metrics(epoch, metrics, "training") + self.save_model(toolbox.model, checkpoint_path("latest")) if epoch % (num_epochs / num_checkpoints) == 0: copy(checkpoint_path("latest"), checkpoint_path(epoch)) self.log(f"Epoch {epoch} checkpoint saved") self.log(f"Epoch {epoch} training completed in {time() - t0:.1f} seconds") # Validation score, metrics = self.validate(toolbox) - self._record("val score", score) + self.record_all({f"val {k}": v for k, v in metrics.items()}) + self.record("val score", score) msg = f"Validation score: {score:.4f}" if epoch > 1: msg += f" ({score - self._metrics["val score"][-2]:+.4f})" @@ -452,7 +482,7 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, co etc = self.etc(epoch, num_epochs, target_epoch=target_epoch) self.log(f"Estimated time of completion in {etc:.1f} seconds at {datetime.fromtimestamp( time() + etc):%m-%d %H:%M:%S}") - self.show_metrics(epoch, metrics=metrics, prefix="validation", epochwise=False) + self.show_metrics(epoch, metrics, "validation", lookup_prefix="val ") if score > self._tracker.best_score: copy(checkpoint_path("latest"), checkpoint_path("best")) self.log(f"======== Best checkpoint updated ({self._tracker.best_score:.4f} -> { @@ -460,11 +490,14 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, co self._tracker.best_score = score early_stop_tolerance = es_tolerance if save_preview: - self.save_preview(*self._tracker.worst_case, quality=preview_quality) + self.save_preview( + *self._validation_dataloader.dataset[self._tracker.worst_case], + fast_load(f"{self.experiment_folder()}/worst_output.pt"), quality=preview_quality + ) else: early_stop_tolerance -= 1 epoch_duration = time() - t0 - self._record("epoch duration", epoch_duration) + self.record("epoch duration", epoch_duration) self.log(f"Epoch {epoch} completed in {epoch_duration:.1f} seconds") self.log(f"=============== Best Validation Score {self._tracker.best_score:.4f} ===============") self.save_metrics() @@ -496,13 +529,18 @@ def train_with_settings(self, num_epochs: int, **kwargs) -> None: # Validation methods @abstractmethod - def validate_case(self, image: torch.Tensor, label: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[ - str, float], torch.Tensor]: + def validate_case(self, idx: int, image: torch.Tensor, label: torch.Tensor, toolbox: TrainerToolbox) -> tuple[ + float, dict[str, float], torch.Tensor]: raise NotImplementedError def validate(self, toolbox: TrainerToolbox) -> tuple[float, dict[str, list[float]]]: if self._validation_dataloader.batch_size != 1: raise RuntimeError("Validation dataloader should have batch size 1") + self.record_profiler_linebreak(f"Validating epoch {self._tracker.epoch}") + self.record_profiler() + self.record_profiler_linebreak("Emptying cache") + self.empty_cache() + self.record_profiler() toolbox.model.eval() if toolbox.ema: toolbox.ema.eval() @@ -513,21 +551,25 @@ def validate(self, toolbox: TrainerToolbox) -> tuple[float, dict[str, list[float with torch.no_grad(), Progress( 
*Progress.get_default_columns(), SpinnerColumn(), console=self._console ) as progress: - val_prog = progress.add_task(f"Validating", total=num_cases) - for image, label in self._validation_dataloader: - image, label = image.to(self._device), label.to(self._device) + task = progress.add_task(f"Validating", total=num_cases) + for idx, (image, label) in enumerate(self._validation_dataloader): + image, label = image.to(self._device, non_blocking=True), label.to(self._device, non_blocking=True) padding_module = self.get_padding_module() if padding_module: image, label = padding_module(image), padding_module(label) image, label = image.squeeze(0), label.squeeze(0) - progress.update(val_prog, description=f"Validating {tuple(image.shape)}") - case_score, case_metrics, output = self.validate_case(image, label, toolbox) + progress.update(task, + description=f"Validating epoch {self._tracker.epoch} case {idx} {tuple(image.shape)}") + case_score, case_metrics, output = self.validate_case(idx, image, label, toolbox) score += case_score if case_score < worst_score: - self._tracker.worst_case = (image, label, output) + self._tracker.worst_case = idx + fast_save(output, f"{self.experiment_folder()}/worst_output.pt") worst_score = case_score try_append_all(case_metrics, metrics) - progress.update(val_prog, advance=1, description=f"Validating ({case_score:.4f})") + progress.update(task, advance=1, + description=f"Validating epoch {self._tracker.epoch} case {idx} ({case_score:.4f})") + self.record_profiler() return score / num_cases, metrics def __call__(self, *args, **kwargs) -> None: @@ -535,4 +577,4 @@ def __call__(self, *args, **kwargs) -> None: @override def __str__(self) -> str: - return f"{self.__class__.__name__} {self._experiment_id}" + return f"{self._trainer_variant} {self._experiment_id}" diff --git a/pyproject.toml b/pyproject.toml index 1c0ccb9..9c9d412 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ authors = [ ] dependencies = [ "pyyaml", "torch", "torchvision", "ptflops", "numpy", "safetensors", "SimpleITK", "matplotlib", "rich", "pandas", - "requests" + "requests", "psutil" ] [project.optional-dependencies] diff --git a/volume_viewer.py b/volume_viewer.py new file mode 100644 index 0000000..1ba1eaf --- /dev/null +++ b/volume_viewer.py @@ -0,0 +1,19 @@ +from mipcandy import fast_load, visualize3d +from mipcandy.presets.segmentation import print_stats_of_class_ids + +case = "006" +image = fast_load(f"S:/SharedWeights/MIPCandy/valPreloaded/images/{case}.pt") +label = fast_load(f"S:/SharedWeights/MIPCandy/valPreloaded/labels/{case}.pt") +print(image.shape, label.shape) +print(label.min(), label.max()) +print_stats_of_class_ids(label, "label", 5) +visualize3d(image, blocking=True) +visualize3d(label, is_label=True, blocking=True) +output = fast_load("S:/SharedWeights/MIPCandy/UNetTrainer/20260213-12-de22-UseThisToDebug/worst_output.pt") +output = output.softmax(0) +print(output.shape) +output = output.argmax(dim=0, keepdim=True) +print(output.shape) +print(output.min(), output.max()) +print_stats_of_class_ids(output, "output", 5) +visualize3d(output, blocking=True)
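
Below is a minimal usage sketch of the standalone Profiler introduced in mipcandy/profiler.py. It is illustrative only: the output path, the "cuda:0" device string, and the main() wrapper are assumptions, not part of this change set, and GPU sampling relies on torch.cuda.utilization, which needs NVML/pynvml to be available. Inside a Trainer the same snapshots are taken automatically via record_profiler()/record_profiler_linebreak() when the trainer is constructed with the new profiler=True flag, which writes profiler.txt into the experiment folder.

import torch

from mipcandy.profiler import Profiler


def main() -> None:
    # Assumed device string; pass CUDA devices explicitly so GPU utilisation and memory are sampled too.
    gpus = ("cuda:0",) if torch.cuda.is_available() else ()
    # Assumed output path; the constructor writes the report header to this file immediately.
    profiler = Profiler("Profiler demo", "profiler.txt", gpus=gpus)
    profiler.line_break("Set-up")  # appends a "<Set-up>" separator together with the elapsed time
    profiler.record()              # appends a snapshot of CPU and memory usage (plus GPUs when given)
    profiler.line_break("Epoch 0")
    profiler.record()


if __name__ == "__main__":
    main()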