diff --git a/spleenseg/transforms/transforms.py b/spleenseg/transforms/transforms.py index 139dd99..6f62470 100644 --- a/spleenseg/transforms/transforms.py +++ b/spleenseg/transforms/transforms.py @@ -28,20 +28,58 @@ from monai.transforms.croppad.dictionary import CropForegroundd, RandCropByPosNegLabeld from monai.transforms.spatial.dictionary import Orientationd, RandAffined, Spacingd from monai.transforms.intensity.dictionary import ScaleIntensityRanged -from typing import Any, Optional, Callable, Hashable, Mapping, Dict, Union +from monai.transforms.io.array import LoadImage +from monai.config.type_definitions import PathLike +from typing import Any, Callable, Hashable, Mapping, Sequence +from monai.data.meta_tensor import MetaTensor import numpy as np +from numpy import ndarray from pathlib import Path from spleenseg.plotting import plotting -def f_LoadImaged() -> Callable[[dict[str, Any]], dict[str, Any]]: - return LoadImaged(keys=["image", "label"]) +def f_LoadImaged( + keys: list[str] = ["image", "label"], +) -> Callable[[dict[str, Any]], dict[str, Any]]: + return LoadImaged(keys=keys) + + +def f_LoadImage() -> Callable[ + [PathLike | Sequence[PathLike]], + torch.Tensor + | Any + | MetaTensor + | tuple[ + torch.Tensor | Any | MetaTensor, + dict[Any, Any] + | Any + | ndarray[Any, Any] + | tuple[Any, ...] + | list[Any] + | bool + | str + | float + | int + | None, + ], +]: + return LoadImage() + + +def f_SaveImaged(outputDir: Path) -> Callable[[dict[str, Any]], dict[str, Any]]: + return SaveImaged( + keys="pred", + meta_keys="pred_meta_dict", + output_dir=str(outputDir), + output_postfix="seg", + resample=False, + ) -def f_EnsureChannelFirstd() -> ( - Callable[[Mapping[Hashable, torch.Tensor]], Mapping[Hashable, torch.Tensor]] -): - return EnsureChannelFirstd(keys=["image", "label"]) +def f_EnsureChannelFirstd( + keys: list[str] = ["image", "label"], +) -> Callable[[Mapping[Hashable, torch.Tensor]], Mapping[Hashable, torch.Tensor]]: + return EnsureChannelFirstd(keys=keys) def f_ScaleIntensityRanged() -> ( @@ -52,25 +90,26 @@ def f_ScaleIntensityRanged() -> ( ) -def f_CropForegroundd() -> ( - Callable[[Mapping[Hashable, torch.Tensor]], Mapping[Hashable, torch.Tensor]] -): - return CropForegroundd(keys=["image", "label"], source_key="image") +def f_CropForegroundd( + keys: list[str] = ["image", "label"], +) -> Callable[[Mapping[Hashable, torch.Tensor]], Mapping[Hashable, torch.Tensor]]: + return CropForegroundd(keys=keys, source_key="image", allow_smaller=True) -def f_Orientationd() -> ( - Callable[[Mapping[Hashable, torch.Tensor]], Mapping[Hashable, torch.Tensor]] -): - return Orientationd(keys=["image", "label"], axcodes="RAS") +def f_Orientationd( + keys: list[str] = ["image", "label"], +) -> Callable[[Mapping[Hashable, torch.Tensor]], Mapping[Hashable, torch.Tensor]]: + return Orientationd(keys=keys, axcodes="RAS") -def f_Spaceingd() -> ( - Callable[[Mapping[Hashable, torch.Tensor]], Mapping[Hashable, torch.Tensor]] -): +def f_Spaceingd( + keys: list[str] = ["image", "label"], + mode: tuple[str, ...] = ("bilinear", "nearest"), +) -> Callable[[Mapping[Hashable, torch.Tensor]], Mapping[Hashable, torch.Tensor]]: return Spacingd( - keys=["image", "label"], + keys=keys, pixdim=(1.5, 1.5, 2.0), - mode=("bilinear", "nearest"), + mode=mode, ) @@ -162,6 +201,30 @@ def trainingAndValidation_transformsSetup() -> tuple[Compose, Compose]: return trainingTransforms, validationTransforms +def validation_transformsOnOriginal() -> Compose: + transforms: list = [ + f_LoadImaged(), + f_EnsureChannelFirstd(), + f_Orientationd(["image"]), + f_Spaceingd(["image"], tuple(["bilinear"])), + f_ScaleIntensityRanged(), + f_CropForegroundd(["image"]), + ] + return transforms_build(transforms) + + +def inferenceUse_transforms() -> Compose: + transforms: list = [ + f_LoadImaged(["image"]), + f_EnsureChannelFirstd(["image"]), + f_Orientationd(["image"]), + f_Spaceingd(["image"], tuple(["bilinear"])), + f_ScaleIntensityRanged(), + f_CropForegroundd(["image"]), + ] + return transforms_build(transforms) + + def transforms_check( outputdir: Path, files: list[dict[str, str]], transforms: Compose ) -> bool: @@ -171,6 +234,9 @@ def transforms_check( if not check_data: return False image, label = (check_data["image"][0][0], check_data["label"][0][0]) - print(f"image shape: {image.shape}, label shape: {label.shape}") + print("") + print("Checking transforms... :") + print(f"sample image shape: {image.shape}") + print(f"sample label shape: {label.shape}") plotting.plot_imageAndLabel(image, label, outputdir / "exemplar_image_label.jpg") return True