diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..365d837 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,39 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: tests + +on: + push: + branches: ["main"] + pull_request: + branches: ["main"] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.10"] #TODO add 3.11 support (for 3.12 torch is not available yet) + + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install flake8 pytest + pip install -e . + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Test with pytest + run: | + pytest diff --git a/.gitignore b/.gitignore index df6df1e..9265c2a 100644 --- a/.gitignore +++ b/.gitignore @@ -26,7 +26,6 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST - # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. @@ -130,5 +129,4 @@ dmypy.json .vscode poetry.lock - -.DS_Store \ No newline at end of file +.DS_Store diff --git a/README.md b/README.md index 35c72b3..145153d 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ [![PyPI version panoptica](https://badge.fury.io/py/brainles-aurora.svg)](https://pypi.python.org/pypi/brainles-aurora/) [![Documentation Status](https://readthedocs.org/projects/brainles-aurora/badge/?version=latest)](http://brainles-aurora.readthedocs.io/?badge=latest) +[![tests](https://github.com/BrainLesion/AURORA/actions/workflows/tests.yml/badge.svg)](https://github.com/BrainLesion/AURORA/actions/workflows/tests.yml) # AURORA diff --git a/brainles_aurora/aux.py b/brainles_aurora/aux.py index 6e4a5b3..6e7a143 100644 --- a/brainles_aurora/aux.py +++ b/brainles_aurora/aux.py @@ -1,4 +1,4 @@ -from path import Path +from pathlib import Path import os diff --git a/brainles_aurora/inferer/constants.py b/brainles_aurora/inferer/constants.py new file mode 100644 index 0000000..a8a8c84 --- /dev/null +++ b/brainles_aurora/inferer/constants.py @@ -0,0 +1,45 @@ +from enum import Enum + + +class InferenceMode(str, Enum): + """Enum representing different modes of inference based on available image inputs.""" + + T1_T1C_T2_FLA = "t1-t1c-t2-fla" + T1_T1C_FLA = "t1-t1c-fla" + T1_T1C = "t1-t1c" + T1C_FLA = "t1c-fla" + T1C_O = "t1c-o" + FLA_O = "fla-o" + T1_O = "t1-o" + + +class ModelSelection(str, Enum): + """Enum representing different strategies for model selection.""" + + BEST = "best" + LAST = "last" + VANILLA = "vanilla" + + +class DataMode(str, Enum): + """Enum representing different modes for handling input and output data. + + Enum Values: + NIFTI_FILE (str): Input data is provided as NIFTI file paths/ output is writte to NIFTI files. + NUMPY (str): Input data is provided as NumPy arrays/ output is returned as NumPy arrays. + """ + + NIFTI_FILE = "NIFTI_FILEPATH" + NUMPY = "NP_NDARRAY" + + +# booleans indicate presence of files in order: T1 T1C T2 FLAIR +IMGS_TO_MODE_DICT = { + (True, True, True, True): InferenceMode.T1_T1C_T2_FLA, + (True, True, False, True): InferenceMode.T1_T1C_FLA, + (True, True, False, False): InferenceMode.T1_T1C, + (False, True, False, True): InferenceMode.T1C_FLA, + (False, True, False, False): InferenceMode.T1C_O, + (False, False, False, True): InferenceMode.FLA_O, + (True, False, False, False): InferenceMode.T1_O, +} diff --git a/brainles_aurora/inferer/dataclasses.py b/brainles_aurora/inferer/dataclasses.py new file mode 100644 index 0000000..813f3b0 --- /dev/null +++ b/brainles_aurora/inferer/dataclasses.py @@ -0,0 +1,68 @@ +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import Tuple + +import numpy as np + +from brainles_aurora.inferer.constants import DataMode, ModelSelection + + +@dataclass +class BaseConfig: + """Base configuration for the Aurora model inferer. + + Attributes: + output_mode (DataMode, optional): Output mode for the inference results. Defaults to DataMode.NIFTI_FILE. + output_folder (str | Path, optional): Output folder for the results. Defaults to "aurora_output". + log_level (int | str, optional): Logging level. Defaults to logging.INFO. + segmentation_file_name (str, optional): File name for the segmentation result. Defaults to "segmentation.nii.gz". (The segmentation will be saved in output_folder/{timestamp}/segmentation_file_name) + t1 (str | Path | np.ndarray | None, optional): Path or NumPy array for T1 image. Defaults to None. + t1c (str | Path | np.ndarray | None, optional): Path or NumPy array for T1 contrast-enhanced image. Defaults to None. + t2 (str | Path | np.ndarray | None, optional): Path or NumPy array for T2 image. Defaults to None. + fla (str | Path | np.ndarray | None, optional): Path or NumPy array for FLAIR image. Defaults to None. + """ + + output_mode: DataMode = DataMode.NIFTI_FILE + output_folder: str | Path = "aurora_output" + segmentation_file_name: str | None = "segmentation.nii.gz" + log_level: int | str = logging.INFO + t1: str | Path | np.ndarray | None = None + t1c: str | Path | np.ndarray | None = None + t2: str | Path | np.ndarray | None = None + fla: str | Path | np.ndarray | None = None + + +@dataclass +class AuroraInfererConfig(BaseConfig): + """Configuration for the Aurora model inferer. + + Attributes: + output_mode (DataMode, optional): Output mode for the inference results. Defaults to DataMode.NIFTI_FILE. + output_folder (str | Path, optional): Output folder for the results. Defaults to "aurora_output". + segmentation_file_name (str, optional): File name for the segmentation result. Defaults to "segmentation.nii.gz". (The segmentation will be saved in output_folder/{timestamp}/segmentation_file_name) + log_level (int | str, optional): Logging level. Defaults to logging.INFO. + t1 (str | Path | np.ndarray | None, optional): Path or NumPy array for T1 image. Defaults to None. + t1c (str | Path | np.ndarray | None, optional): Path or NumPy array for T1 contrast-enhanced image. Defaults to None. + t2 (str | Path | np.ndarray | None, optional): Path or NumPy array for T2 image. Defaults to None. + fla (str | Path | np.ndarray | None, optional): Path or NumPy array for FLAIR image. Defaults to None. + output_whole_network (bool, optional): Whether to output the whole network results. Defaults to False. + output_metastasis_network (bool, optional): Whether to output the metastasis network results. Defaults to False. + tta (bool, optional): Whether to apply test-time augmentations. Defaults to True. + sliding_window_batch_size (int, optional): Batch size for sliding window inference. Defaults to 1. + workers (int, optional): Number of workers for data loading. Defaults to 0. + threshold (float, optional): Threshold for binarizing the model outputs. Defaults to 0.5. + sliding_window_overlap (float, optional): Overlap ratio for sliding window inference. Defaults to 0.5. + crop_size (Tuple[int, int, int], optional): Crop size for sliding window inference. Defaults to (192, 192, 32). + model_selection (ModelSelection, optional): Model selection strategy. Defaults to ModelSelection.BEST. + """ + + output_whole_network: bool = False + output_metastasis_network: bool = False + tta: bool = True + sliding_window_batch_size: int = 1 + workers: int = 0 + threshold: float = 0.5 + sliding_window_overlap: float = 0.5 + crop_size: Tuple[int, int, int] = (192, 192, 32) + model_selection: ModelSelection = ModelSelection.BEST diff --git a/brainles_aurora/download.py b/brainles_aurora/inferer/download.py similarity index 96% rename from brainles_aurora/download.py rename to brainles_aurora/inferer/download.py index 3cdc833..34a5ac6 100644 --- a/brainles_aurora/download.py +++ b/brainles_aurora/inferer/download.py @@ -1,10 +1,10 @@ # copied from https://github.com/Nordgaren/Github-Folder-Downloader/blob/master/gitdl.py import os -from github import Github, Repository, ContentFile -import requests - import shutil as sh +import requests +from github import ContentFile, Github, Repository + def download(c: ContentFile, out: str): r = requests.get(c.download_url) diff --git a/brainles_aurora/inferer/inferer.py b/brainles_aurora/inferer/inferer.py new file mode 100644 index 0000000..d642f9a --- /dev/null +++ b/brainles_aurora/inferer/inferer.py @@ -0,0 +1,500 @@ +import logging +import os +from abc import ABC, abstractmethod +from datetime import datetime +from pathlib import Path +import sys +from typing import Dict, List + +import monai +import nibabel as nib +import numpy as np +import torch +from monai.data import list_data_collate +from monai.inferers import SlidingWindowInferer +from monai.networks.nets import BasicUNet +from monai.transforms import ( + Compose, + EnsureChannelFirstd, + Lambdad, + LoadImageD, + RandGaussianNoised, + ScaleIntensityRangePercentilesd, + ToTensord, +) +from torch.utils.data import DataLoader + +from brainles_aurora.inferer.constants import IMGS_TO_MODE_DICT, DataMode, InferenceMode +from brainles_aurora.aux import turbo_path +from brainles_aurora.inferer.dataclasses import AuroraInfererConfig, BaseConfig +from brainles_aurora.inferer.download import download_model_weights + +LIB_ABSPATH: str = os.path.dirname(os.path.abspath(__file__)) + +MODEL_WEIGHTS_DIR = Path(LIB_ABSPATH).parent / "model_weights" +if not MODEL_WEIGHTS_DIR.exists(): + download_model_weights(target_folder=LIB_ABSPATH) + + +class AbstractInferer(ABC): + """ + Abstract base class for inference. + + Attributes: + config (BaseConfig): The configuration for the inferer. + output_folder (Path): The output folder for the inferer. Follows the schema {config.output_folder}/{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')} + """ + + def __init__(self, config: BaseConfig) -> None: + """Initialize the abstract inferer. Sets up the logger and output folder. + + Args: + config (BaseConfig): Configuration for the inferer. + """ + self.config = config + + # setup output folder + self.output_folder = ( + Path(os.path.abspath(self.config.output_folder)) + / f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + ) + self.output_folder.mkdir(exist_ok=True, parents=True) + + # setup logger + self._setup_logger() + + def _setup_logger(self) -> None: + """Set up the logger for the inferer.""" + + self.log_path = self.output_folder / f"{self.config.segmentation_file_name}.log" + + logging.basicConfig( + # stream=sys.stderr, + format="%(asctime)s %(levelname)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=self.config.log_level, + encoding="utf-8", + handlers=[ + logging.StreamHandler(), + logging.FileHandler(self.log_path), + ], + ) + + class CustomStdErrStream: + """Capture stderr and log it to the logger.""" + + def write(self, msg: str): + if msg := msg.rstrip(): + logging.error(msg) + + # sys.stderr = CustomStdErrStream() + + logging.info(f"Logging to: {self.log_path}") + + @abstractmethod + def infer(self): + pass + + +class AuroraInferer(AbstractInferer): + """Inferer for the CPU Aurora models.""" + + def __init__(self, config: AuroraInfererConfig) -> None: + """Initialize the AuroraInferer. + + Args: + config (AuroraInfererConfig): Configuration for the Aurora inferer. + """ + # TODO move weights path / download to config and setup + super().__init__(config=config) + + logging.info( + f"Initialized {self.__class__.__name__} with config: {self.config}" + ) + + self.images = self._validate_images() + self.mode = self._determine_inference_mode() + + self.device = self._configure_device() + logging.info("Setting up Dataloader") + self.data_loader = self._get_data_loader() + logging.info("Loading Model and weights") + self.model = self._get_model() + + def _validate_images(self) -> List[np.ndarray | None] | List[Path | None]: + """Validate input images, sets the input mode and returns the list of validated images. + + Returns: + List[np.ndarray | None] | List[Path | None]: List of validated images. + """ + + def _validate_image( + data: str | Path | np.ndarray | None, + ) -> np.ndarray | Path | None: + if data is None: + return None + if isinstance(data, np.ndarray): + self.input_mode = DataMode.NUMPY + return data.astype(np.float32) + if not os.path.exists(data): + raise FileNotFoundError(f"File {data} not found") + if not (data.endswith(".nii.gz") or data.endswith(".nii")): + raise ValueError( + f"File {data} must be a nifti file with extension .nii or .nii.gz" + ) + self.input_mode = DataMode.NIFTI_FILE + return turbo_path(data) + + images = [ + _validate_image(img) + for img in [ + self.config.t1, + self.config.t1c, + self.config.t2, + self.config.fla, + ] + ] + + not_none_images = [img for img in images if img is not None] + assert len(not_none_images) > 0, "No input images provided" + # make sure all inputs have the same type + unique_types = set(map(type, not_none_images)) + assert ( + len(unique_types) == 1 + ), f"All passed images must be of the same type! Received {unique_types}. Accepted Input types: {list(DataMode)}" + + logging.info( + f"Successfully validated input images. Input mode: {self.input_mode}" + ) + return images + + def _determine_inference_mode(self) -> InferenceMode: + """Determine the inference mode based on the provided images. + + Raises: + NotImplementedError: If no model is implemented for the combination of input images. + Returns: + InferenceMode: Inference mode based on the combination of input images. + """ + _t1, _t1c, _t2, _fla = [img is not None for img in self.images] + logging.info( + f"Received files: T1: {_t1}, T1C: {_t1c}, T2: {_t2}, FLAIR: {_fla}" + ) + + # check if files are given in a valid combination that has an existing model implementation + mode = IMGS_TO_MODE_DICT.get((_t1, _t1c, _t2, _fla), None) + + if mode is None: + raise NotImplementedError( + "No model implemented for this combination of images" + ) + + logging.info(f"Inference mode: {mode}") + return mode + + def _get_data_loader(self) -> torch.utils.data.DataLoader: + """Get the data loader for inference. + + Returns: + torch.utils.data.DataLoader: Data loader for inference. + """ + # init transforms + transforms = [ + LoadImageD(keys=["images"]) + if self.input_mode == DataMode.NIFTI_FILE + else None, + EnsureChannelFirstd(keys="images") + if len(self._get_not_none_files()) == 1 + else None, + Lambdad(["images"], np.nan_to_num), + ScaleIntensityRangePercentilesd( + keys="images", + lower=0.5, + upper=99.5, + b_min=0, + b_max=1, + clip=True, + relative=False, + channel_wise=True, + ), + ToTensord(keys=["images"]), + ] + # Filter None transforms + transforms = list(filter(None, transforms)) + inference_transforms = Compose(transforms) + + # Initialize data dictionary + data = { + key: getattr(self.config, key) + for key in ["t1", "t1c", "t2", "fla"] + if getattr(self.config, key) is not None + } + # method returns files in standard order T1 T1C T2 FLAIR + data["images"] = self._get_not_none_files() + + # init dataset and dataloader + infererence_ds = monai.data.Dataset( + data=[data], + transform=inference_transforms, + ) + + data_loader = DataLoader( + infererence_ds, + batch_size=1, + num_workers=self.config.workers, + collate_fn=list_data_collate, + shuffle=False, + ) + return data_loader + + def _get_model(self) -> torch.nn.Module: + """Get the Aurora model based on the inference mode. + + Returns: + torch.nn.Module: Aurora model. + """ + + # fuckery: + x = 70 / 0 + # init model + model = BasicUNet( + spatial_dims=3, + in_channels=len(self._get_not_none_files()), + out_channels=2, + features=(32, 32, 64, 128, 256, 32), + dropout=0.1, + act="mish", + ) + + # load weights + weights_path = os.path.join( + MODEL_WEIGHTS_DIR, + self.mode, + f"{self.config.model_selection}.tar", + ) + + if not os.path.exists(weights_path): + raise NotImplementedError( + f"No weights found for model {self.mode} and selection {self.config.model_selection}" + ) + + model = model.to(self.device) + checkpoint = torch.load(weights_path, map_location=self.device) + + # The models were trained using DataParallel, hence we need to remove the 'module.' prefix + # for cpu inference to enable checkpoint loading (since DataParallel is not usable for CPU) + if self.device == torch.device("cpu"): + if "module." in list(checkpoint["model_state"].keys())[0]: + checkpoint["model_state"] = { + k.replace("module.", ""): v + for k, v in checkpoint["model_state"].items() + } + else: + model = torch.nn.parallel.DataParallel(model) + + model.load_state_dict(checkpoint["model_state"]) + + return model + + def _apply_test_time_augmentations( + self, outputs: torch.Tensor, data: Dict, inferer: SlidingWindowInferer + ) -> torch.Tensor: + """Apply test time augmentations to the model outputs. + + Args: + outputs (torch.Tensor): Model outputs. + data (Dict): Input data. + inferer (SlidingWindowInferer): Sliding window inferer. + + Returns: + torch.Tensor: Augmented model outputs. + """ + n = 1.0 + for _ in range(4): + # test time augmentations + _img = RandGaussianNoised(keys="images", prob=1.0, std=0.001)(data)[ + "images" + ] + + output = inferer(_img, self.model) + outputs = outputs + output + n += 1.0 + for dims in [[2], [3]]: + flip_pred = inferer(torch.flip(_img, dims=dims), self.model) + + output = torch.flip(flip_pred, dims=dims) + outputs = outputs + output + n += 1.0 + outputs = outputs / n + return outputs + + def _get_not_none_files(self) -> List[np.ndarray] | List[Path]: + """Get the list of non-None input images in order T1-T1C-T2-FLA. + + Returns: + List[np.ndarray] | List[Path]: List of non-None images. + """ + return [img for img in self.images if img is not None] + + def _save_as_nifti(self, postproc_data: Dict[str, np.ndarray]) -> None: + """Save post-processed data as NIFTI files. + + Args: + postproc_data (Dict[str, np.ndarray]): Post-processed data. + """ + # determine affine/ header + if self.input_mode == DataMode.NIFTI_FILE: + reference_file = self._get_not_none_files()[0] + ref = nib.load(reference_file) + affine, header = ref.affine, ref.header + else: + logging.warning( + f"Writing NIFTI output after NumPy input, using default affine=np.eye(4) and header=None" + ) + affine, header = np.eye(4), None + + logging.info(f"Output folder set to {self.output_folder}") + + # save niftis + for key, data in postproc_data.items(): + # TODO: verify and make enum? + if key == "segmentation": + output_file = self.output_folder / self.config.segmentation_file_name + else: + output_file = self.output_folder / f"{key}.nii.gz" + output_image = nib.Nifti1Image(data, affine, header) + nib.save(output_image, output_file) + logging.info(f"Saved {key} to {output_file}") + + def _post_process( + self, onehot_model_outputs_CHWD: torch.Tensor + ) -> Dict[str, np.ndarray]: + """Post-process the model outputs. + + Args: + onehot_model_outputs_CHWD (torch.Tensor): One-hot encoded model outputs. + + Returns: + Dict[str, np.ndarray]: Post-processed data. + """ + + # create segmentations + activated_outputs = ( + (onehot_model_outputs_CHWD[0][:, :, :, :].sigmoid()).detach().cpu().numpy() + ) + binarized_outputs = activated_outputs >= self.config.threshold + binarized_outputs = binarized_outputs.astype(np.uint8) + + whole_metastasis = binarized_outputs[0] + enhancing_metastasis = binarized_outputs[1] + + final_seg = whole_metastasis.copy() + final_seg[whole_metastasis == 1] = 1 # edema + final_seg[enhancing_metastasis == 1] = 2 # enhancing + + whole_out = binarized_outputs[0] + enhancing_out = binarized_outputs[1] + + # create output dict based on config + data = {"segmentation": final_seg} + if self.config.output_whole_network: + data["output_whole_network"] = whole_out + if self.config.output_metastasis_network: + data["output_metastasis_network"] = enhancing_out + return data + + def _sliding_window_inference(self) -> None | Dict[str, np.ndarray]: + """Perform sliding window inference using monai.inferers.SlidingWindowInferer. + + Returns: + None | Dict[str, np.ndarray]: Post-processed data if output_mode is NUMPY, otherwise the data is saved as a niftis and None is returned. + """ + inferer = SlidingWindowInferer( + roi_size=self.config.crop_size, # = patch_size + sw_batch_size=self.config.sliding_window_batch_size, + sw_device=self.device, + device=self.device, + overlap=self.config.sliding_window_overlap, + mode="gaussian", + padding_mode="replicate", + ) + + with torch.no_grad(): + self.model.eval() + self.model = self.model.to(self.device) + # loop through batches, only 1 batch! + for data in self.data_loader: + inputs = data["images"].to(self.device) + + outputs = inferer(inputs, self.model) + if self.config.tta: + logging.info("Applying test time augmentations") + outputs = self._apply_test_time_augmentations( + outputs, data, inferer + ) + + postprocessed_data = self._post_process( + onehot_model_outputs_CHWD=outputs, + ) + if self.config.output_mode == DataMode.NUMPY: + return postprocessed_data + else: + self._save_as_nifti(postproc_data=postprocessed_data) + return + + def _configure_device(self) -> torch.device: + """Configure the device for inference. + + Returns: + torch.device: Configured device. + """ + device = torch.device("cpu") + logging.info(f"Using device: {device}") + return device + + def infer(self) -> None: + """Run the inference process.""" + logging.info(f"Running inference on {self.device}") + return self._sliding_window_inference() + + +#################### +# GPU Inferer +#################### +class AuroraGPUInferer(AuroraInferer): + """Inferer for the Aurora models on GPU.""" + + def __init__( + self, + config: AuroraInfererConfig, + cuda_devices: str = "0", + ) -> None: + """Initialize the AuroraGPUInferer. + + Args: + config (AuroraInfererConfig): Configuration for the Aurora GPU inferer. + cuda_devices (str, optional): CUDA devices to use. Defaults to "0". + """ + self.cuda_devices = cuda_devices + + super().__init__(config=config) + + def _configure_device(self) -> torch.device: + """Configure the GPU device for inference. + + Returns: + torch.device: Configured GPU device. + """ + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = self.cuda_devices + + assert ( + torch.cuda.is_available() + ), "No cuda device available while using GPUInferer" + + device = torch.device("cuda") + logging.info(f"Using device: {device}") + + # clean memory + torch.cuda.empty_cache() + return device diff --git a/brainles_aurora/lib.py b/brainles_aurora/lib.py index 7255c27..cb510b2 100644 --- a/brainles_aurora/lib.py +++ b/brainles_aurora/lib.py @@ -437,7 +437,8 @@ def single_inference( metastasis_network_outputs_file=None, cuda_devices="0", tta=True, - sliding_window_batch_size=1, # faster for single interference (on RTX 3090) + # faster for single interference (on RTX 3090) + sliding_window_batch_size=1, workers=0, threshold=0.5, sliding_window_overlap=0.5, diff --git a/brainles_aurora/model_weights/fla-o/fla-o_best.tar b/brainles_aurora/model_weights/fla-o/best.tar similarity index 100% rename from brainles_aurora/model_weights/fla-o/fla-o_best.tar rename to brainles_aurora/model_weights/fla-o/best.tar diff --git a/brainles_aurora/model_weights/fla-o/fla-o_last.tar b/brainles_aurora/model_weights/fla-o/last.tar similarity index 100% rename from brainles_aurora/model_weights/fla-o/fla-o_last.tar rename to brainles_aurora/model_weights/fla-o/last.tar diff --git a/brainles_aurora/model_weights/t1-o/t1-o_best.tar b/brainles_aurora/model_weights/t1-o/best.tar similarity index 100% rename from brainles_aurora/model_weights/t1-o/t1-o_best.tar rename to brainles_aurora/model_weights/t1-o/best.tar diff --git a/brainles_aurora/model_weights/t1-o/t1-o_last.tar b/brainles_aurora/model_weights/t1-o/last.tar similarity index 100% rename from brainles_aurora/model_weights/t1-o/t1-o_last.tar rename to brainles_aurora/model_weights/t1-o/last.tar diff --git a/brainles_aurora/model_weights/t1c-t1-fla/t1c-t1-fla_best.tar b/brainles_aurora/model_weights/t1-t1c-fla/best.tar similarity index 100% rename from brainles_aurora/model_weights/t1c-t1-fla/t1c-t1-fla_best.tar rename to brainles_aurora/model_weights/t1-t1c-fla/best.tar diff --git a/brainles_aurora/model_weights/t1c-t1-fla/t1c-t1-fla_last.tar b/brainles_aurora/model_weights/t1-t1c-fla/last.tar similarity index 100% rename from brainles_aurora/model_weights/t1c-t1-fla/t1c-t1-fla_last.tar rename to brainles_aurora/model_weights/t1-t1c-fla/last.tar diff --git a/brainles_aurora/model_weights/t1-t1c-t2-fla/t1-t1c-t2-fla_best.tar b/brainles_aurora/model_weights/t1-t1c-t2-fla/best.tar similarity index 100% rename from brainles_aurora/model_weights/t1-t1c-t2-fla/t1-t1c-t2-fla_best.tar rename to brainles_aurora/model_weights/t1-t1c-t2-fla/best.tar diff --git a/brainles_aurora/model_weights/t1-t1c-t2-fla/t1-t1c-t2-fla_last.tar b/brainles_aurora/model_weights/t1-t1c-t2-fla/last.tar similarity index 100% rename from brainles_aurora/model_weights/t1-t1c-t2-fla/t1-t1c-t2-fla_last.tar rename to brainles_aurora/model_weights/t1-t1c-t2-fla/last.tar diff --git a/brainles_aurora/model_weights/t1c-t1/t1c-t1_best.tar b/brainles_aurora/model_weights/t1-t1c/best.tar similarity index 100% rename from brainles_aurora/model_weights/t1c-t1/t1c-t1_best.tar rename to brainles_aurora/model_weights/t1-t1c/best.tar diff --git a/brainles_aurora/model_weights/t1c-t1/t1c-t1_last.tar b/brainles_aurora/model_weights/t1-t1c/last.tar similarity index 100% rename from brainles_aurora/model_weights/t1c-t1/t1c-t1_last.tar rename to brainles_aurora/model_weights/t1-t1c/last.tar diff --git a/brainles_aurora/model_weights/t1c-fla/t1c-fla_best.tar b/brainles_aurora/model_weights/t1c-fla/best.tar similarity index 100% rename from brainles_aurora/model_weights/t1c-fla/t1c-fla_best.tar rename to brainles_aurora/model_weights/t1c-fla/best.tar diff --git a/brainles_aurora/model_weights/t1c-fla/t1c-fla_last.tar b/brainles_aurora/model_weights/t1c-fla/last.tar similarity index 100% rename from brainles_aurora/model_weights/t1c-fla/t1c-fla_last.tar rename to brainles_aurora/model_weights/t1c-fla/last.tar diff --git a/brainles_aurora/model_weights/t1c-o/t1c-o_best.tar b/brainles_aurora/model_weights/t1c-o/best.tar similarity index 100% rename from brainles_aurora/model_weights/t1c-o/t1c-o_best.tar rename to brainles_aurora/model_weights/t1c-o/best.tar diff --git a/brainles_aurora/model_weights/t1c-o/t1c-o_last.tar b/brainles_aurora/model_weights/t1c-o/last.tar similarity index 100% rename from brainles_aurora/model_weights/t1c-o/t1c-o_last.tar rename to brainles_aurora/model_weights/t1c-o/last.tar diff --git a/pyproject.toml b/pyproject.toml index af4a3df..cf85f6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ license = "AGPL-3.0" authors = [ "Florian Kofler ", "Isra Mekki ", + "Marcel Rosier ", ] maintainers = [ @@ -38,10 +39,10 @@ exclude = ["brainles_aurora/model_weights"] [tool.poetry.dependencies] # core python = "^3.10" -monai = "^1.2.0" -torch = "^2.1.0" -nibabel = "^4.0.2" -numpy = "^1.23.0" +monai = ">=1.2.0" +torch = ">=2.1.0" +nibabel = ">=4.0.2" +numpy = ">=1.23.0" # utils PyGithub = "^1.57" diff --git a/segmentation_test.py b/segmentation_test.py index 4204a04..60a54d0 100644 --- a/segmentation_test.py +++ b/segmentation_test.py @@ -1,8 +1,80 @@ -from brainles_aurora.lib import single_inference - -single_inference( - t1c_file="example_data/BraTS-MET-00110-000-t1c.nii.gz", - segmentation_file="your_segmentation_file.nii.gz", - tta=False, # optional: whether to use test time augmentations - verbosity=True, # optional: verbosity of the output +from brainles_aurora.inferer.constants import DataMode +from brainles_aurora.inferer.inferer import ( + AuroraInferer, + AuroraGPUInferer, + AuroraInfererConfig, ) +import os +from path import Path +import nibabel as nib + +BASE_PATH = Path(os.path.abspath(__file__)).parent + +t1 = BASE_PATH / "example_data/BraTS-MET-00110-000-t1n.nii.gz" +t1c = BASE_PATH / "example_data/BraTS-MET-00110-000-t1c.nii.gz" +t2 = BASE_PATH / "example_data/BraTS-MET-00110-000-t2w.nii.gz" +fla = BASE_PATH / "example_data/BraTS-MET-00110-000-t2f.nii.gz" + + +def load_np_from_nifti(path): + return nib.load(path).get_fdata() + + +def gpu_nifti(): + config = AuroraInfererConfig( + t1=t1, + t1c=t1c, + t2=t2, + fla=fla, + output_metastasis_network=True, + output_whole_network=True, + ) + inferer = AuroraGPUInferer( + config=config, + ) + inferer.infer() + + +def cpu_nifti(): + config = AuroraInfererConfig( + t1=t1, + t1c=t1c, + t2=t2, + fla=fla, + ) + inferer = AuroraInferer( + config=config, + ) + inferer.infer() + + +def gpu_np(): + config = AuroraInfererConfig( + t1=load_np_from_nifti(t1), + t1c=load_np_from_nifti(t1c), + t2=load_np_from_nifti(t2), + fla=load_np_from_nifti(fla), + ) + inferer = AuroraGPUInferer( + config=config, + ) + inferer.infer() + + +def gpu_output_np(): + config = AuroraInfererConfig( + t1=load_np_from_nifti(t1), + t1c=load_np_from_nifti(t1c), + t2=load_np_from_nifti(t2), + fla=load_np_from_nifti(fla), + output_mode=DataMode.NUMPY, + ) + inferer = AuroraGPUInferer( + config=config, + ) + data = inferer.infer() + print(data) + + +if __name__ == "__main__": + gpu_nifti() diff --git a/tester.ipynb b/tester.ipynb new file mode 100644 index 0000000..40a3024 --- /dev/null +++ b/tester.ipynb @@ -0,0 +1,242 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "t1: False t1c: True t2: False flair: False\n", + "mode: t1c-o\n", + "BasicUNet features: (32, 32, 64, 128, 256, 32).\n" + ] + } + ], + "source": [ + "from brainles_aurora.core import infer\n", + "from brainles_aurora.lib import single_inference\n", + "from brainles_aurora.aux import turbo_path\n", + "from brainles_aurora.enums import InferenceMode\n", + "\n", + "\n", + "t1c_file = \"example_data/BraTS-MET-00110-000-t1c.nii.gz\"\n", + "segmentation_file = \"your_segmentation_file.nii.gz\"\n", + "tta = False # optional: whether to use test time augmentations\n", + "verbosity = True # optional: verbosity of the output\n", + "\n", + "c = single_inference(\n", + " t1c_file=t1c_file, segmentation_file=segmentation_file, tta=tta, verbosity=verbosity\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "t1: False t1c: True t2: False flair: False\n", + "mode: InferenceMode.T1C_O\n", + "BasicUNet features: (32, 32, 64, 128, 256, 32).\n", + "odict_keys(['module.conv_0.conv_0.conv.weight', 'module.conv_0.conv_0.conv.bias', 'module.conv_0.conv_0.adn.N.weight', 'module.conv_0.conv_0.adn.N.bias', 'module.conv_0.conv_1.conv.weight', 'module.conv_0.conv_1.conv.bias', 'module.conv_0.conv_1.adn.N.weight', 'module.conv_0.conv_1.adn.N.bias', 'module.down_1.convs.conv_0.conv.weight', 'module.down_1.convs.conv_0.conv.bias', 'module.down_1.convs.conv_0.adn.N.weight', 'module.down_1.convs.conv_0.adn.N.bias', 'module.down_1.convs.conv_1.conv.weight', 'module.down_1.convs.conv_1.conv.bias', 'module.down_1.convs.conv_1.adn.N.weight', 'module.down_1.convs.conv_1.adn.N.bias', 'module.down_2.convs.conv_0.conv.weight', 'module.down_2.convs.conv_0.conv.bias', 'module.down_2.convs.conv_0.adn.N.weight', 'module.down_2.convs.conv_0.adn.N.bias', 'module.down_2.convs.conv_1.conv.weight', 'module.down_2.convs.conv_1.conv.bias', 'module.down_2.convs.conv_1.adn.N.weight', 'module.down_2.convs.conv_1.adn.N.bias', 'module.down_3.convs.conv_0.conv.weight', 'module.down_3.convs.conv_0.conv.bias', 'module.down_3.convs.conv_0.adn.N.weight', 'module.down_3.convs.conv_0.adn.N.bias', 'module.down_3.convs.conv_1.conv.weight', 'module.down_3.convs.conv_1.conv.bias', 'module.down_3.convs.conv_1.adn.N.weight', 'module.down_3.convs.conv_1.adn.N.bias', 'module.down_4.convs.conv_0.conv.weight', 'module.down_4.convs.conv_0.conv.bias', 'module.down_4.convs.conv_0.adn.N.weight', 'module.down_4.convs.conv_0.adn.N.bias', 'module.down_4.convs.conv_1.conv.weight', 'module.down_4.convs.conv_1.conv.bias', 'module.down_4.convs.conv_1.adn.N.weight', 'module.down_4.convs.conv_1.adn.N.bias', 'module.upcat_4.upsample.deconv.weight', 'module.upcat_4.upsample.deconv.bias', 'module.upcat_4.convs.conv_0.conv.weight', 'module.upcat_4.convs.conv_0.conv.bias', 'module.upcat_4.convs.conv_0.adn.N.weight', 'module.upcat_4.convs.conv_0.adn.N.bias', 'module.upcat_4.convs.conv_1.conv.weight', 'module.upcat_4.convs.conv_1.conv.bias', 'module.upcat_4.convs.conv_1.adn.N.weight', 'module.upcat_4.convs.conv_1.adn.N.bias', 'module.upcat_3.upsample.deconv.weight', 'module.upcat_3.upsample.deconv.bias', 'module.upcat_3.convs.conv_0.conv.weight', 'module.upcat_3.convs.conv_0.conv.bias', 'module.upcat_3.convs.conv_0.adn.N.weight', 'module.upcat_3.convs.conv_0.adn.N.bias', 'module.upcat_3.convs.conv_1.conv.weight', 'module.upcat_3.convs.conv_1.conv.bias', 'module.upcat_3.convs.conv_1.adn.N.weight', 'module.upcat_3.convs.conv_1.adn.N.bias', 'module.upcat_2.upsample.deconv.weight', 'module.upcat_2.upsample.deconv.bias', 'module.upcat_2.convs.conv_0.conv.weight', 'module.upcat_2.convs.conv_0.conv.bias', 'module.upcat_2.convs.conv_0.adn.N.weight', 'module.upcat_2.convs.conv_0.adn.N.bias', 'module.upcat_2.convs.conv_1.conv.weight', 'module.upcat_2.convs.conv_1.conv.bias', 'module.upcat_2.convs.conv_1.adn.N.weight', 'module.upcat_2.convs.conv_1.adn.N.bias', 'module.upcat_1.upsample.deconv.weight', 'module.upcat_1.upsample.deconv.bias', 'module.upcat_1.convs.conv_0.conv.weight', 'module.upcat_1.convs.conv_0.conv.bias', 'module.upcat_1.convs.conv_0.adn.N.weight', 'module.upcat_1.convs.conv_0.adn.N.bias', 'module.upcat_1.convs.conv_1.conv.weight', 'module.upcat_1.convs.conv_1.conv.bias', 'module.upcat_1.convs.conv_1.adn.N.weight', 'module.upcat_1.convs.conv_1.adn.N.bias', 'module.final_conv.weight', 'module.final_conv.bias'])\n" + ] + }, + { + "ename": "RuntimeError", + "evalue": "Error(s) in loading state_dict for BasicUNet:\n\tMissing key(s) in state_dict: \"conv_0.conv_0.conv.weight\", \"conv_0.conv_0.conv.bias\", \"conv_0.conv_0.adn.N.weight\", \"conv_0.conv_0.adn.N.bias\", \"conv_0.conv_1.conv.weight\", \"conv_0.conv_1.conv.bias\", \"conv_0.conv_1.adn.N.weight\", \"conv_0.conv_1.adn.N.bias\", \"down_1.convs.conv_0.conv.weight\", \"down_1.convs.conv_0.conv.bias\", \"down_1.convs.conv_0.adn.N.weight\", \"down_1.convs.conv_0.adn.N.bias\", \"down_1.convs.conv_1.conv.weight\", \"down_1.convs.conv_1.conv.bias\", \"down_1.convs.conv_1.adn.N.weight\", \"down_1.convs.conv_1.adn.N.bias\", \"down_2.convs.conv_0.conv.weight\", \"down_2.convs.conv_0.conv.bias\", \"down_2.convs.conv_0.adn.N.weight\", \"down_2.convs.conv_0.adn.N.bias\", \"down_2.convs.conv_1.conv.weight\", \"down_2.convs.conv_1.conv.bias\", \"down_2.convs.conv_1.adn.N.weight\", \"down_2.convs.conv_1.adn.N.bias\", \"down_3.convs.conv_0.conv.weight\", \"down_3.convs.conv_0.conv.bias\", \"down_3.convs.conv_0.adn.N.weight\", \"down_3.convs.conv_0.adn.N.bias\", \"down_3.convs.conv_1.conv.weight\", \"down_3.convs.conv_1.conv.bias\", \"down_3.convs.conv_1.adn.N.weight\", \"down_3.convs.conv_1.adn.N.bias\", \"down_4.convs.conv_0.conv.weight\", \"down_4.convs.conv_0.conv.bias\", \"down_4.convs.conv_0.adn.N.weight\", \"down_4.convs.conv_0.adn.N.bias\", \"down_4.convs.conv_1.conv.weight\", \"down_4.convs.conv_1.conv.bias\", \"down_4.convs.conv_1.adn.N.weight\", \"down_4.convs.conv_1.adn.N.bias\", \"upcat_4.upsample.deconv.weight\", \"upcat_4.upsample.deconv.bias\", \"upcat_4.convs.conv_0.conv.weight\", \"upcat_4.convs.conv_0.conv.bias\", \"upcat_4.convs.conv_0.adn.N.weight\", \"upcat_4.convs.conv_0.adn.N.bias\", \"upcat_4.convs.conv_1.conv.weight\", \"upcat_4.convs.conv_1.conv.bias\", \"upcat_4.convs.conv_1.adn.N.weight\", \"upcat_4.convs.conv_1.adn.N.bias\", \"upcat_3.upsample.deconv.weight\", \"upcat_3.upsample.deconv.bias\", \"upcat_3.convs.conv_0.conv.weight\", \"upcat_3.convs.conv_0.conv.bias\", \"upcat_3.convs.conv_0.adn.N.weight\", \"upcat_3.convs.conv_0.adn.N.bias\", \"upcat_3.convs.conv_1.conv.weight\", \"upcat_3.convs.conv_1.conv.bias\", \"upcat_3.convs.conv_1.adn.N.weight\", \"upcat_3.convs.conv_1.adn.N.bias\", \"upcat_2.upsample.deconv.weight\", \"upcat_2.upsample.deconv.bias\", \"upcat_2.convs.conv_0.conv.weight\", \"upcat_2.convs.conv_0.conv.bias\", \"upcat_2.convs.conv_0.adn.N.weight\", \"upcat_2.convs.conv_0.adn.N.bias\", \"upcat_2.convs.conv_1.conv.weight\", \"upcat_2.convs.conv_1.conv.bias\", \"upcat_2.convs.conv_1.adn.N.weight\", \"upcat_2.convs.conv_1.adn.N.bias\", \"upcat_1.upsample.deconv.weight\", \"upcat_1.upsample.deconv.bias\", \"upcat_1.convs.conv_0.conv.weight\", \"upcat_1.convs.conv_0.conv.bias\", \"upcat_1.convs.conv_0.adn.N.weight\", \"upcat_1.convs.conv_0.adn.N.bias\", \"upcat_1.convs.conv_1.conv.weight\", \"upcat_1.convs.conv_1.conv.bias\", \"upcat_1.convs.conv_1.adn.N.weight\", \"upcat_1.convs.conv_1.adn.N.bias\", \"final_conv.weight\", \"final_conv.bias\". \n\tUnexpected key(s) in state_dict: \"module.conv_0.conv_0.conv.weight\", \"module.conv_0.conv_0.conv.bias\", \"module.conv_0.conv_0.adn.N.weight\", \"module.conv_0.conv_0.adn.N.bias\", \"module.conv_0.conv_1.conv.weight\", \"module.conv_0.conv_1.conv.bias\", \"module.conv_0.conv_1.adn.N.weight\", \"module.conv_0.conv_1.adn.N.bias\", \"module.down_1.convs.conv_0.conv.weight\", \"module.down_1.convs.conv_0.conv.bias\", \"module.down_1.convs.conv_0.adn.N.weight\", \"module.down_1.convs.conv_0.adn.N.bias\", \"module.down_1.convs.conv_1.conv.weight\", \"module.down_1.convs.conv_1.conv.bias\", \"module.down_1.convs.conv_1.adn.N.weight\", \"module.down_1.convs.conv_1.adn.N.bias\", \"module.down_2.convs.conv_0.conv.weight\", \"module.down_2.convs.conv_0.conv.bias\", \"module.down_2.convs.conv_0.adn.N.weight\", \"module.down_2.convs.conv_0.adn.N.bias\", \"module.down_2.convs.conv_1.conv.weight\", \"module.down_2.convs.conv_1.conv.bias\", \"module.down_2.convs.conv_1.adn.N.weight\", \"module.down_2.convs.conv_1.adn.N.bias\", \"module.down_3.convs.conv_0.conv.weight\", \"module.down_3.convs.conv_0.conv.bias\", \"module.down_3.convs.conv_0.adn.N.weight\", \"module.down_3.convs.conv_0.adn.N.bias\", \"module.down_3.convs.conv_1.conv.weight\", \"module.down_3.convs.conv_1.conv.bias\", \"module.down_3.convs.conv_1.adn.N.weight\", \"module.down_3.convs.conv_1.adn.N.bias\", \"module.down_4.convs.conv_0.conv.weight\", \"module.down_4.convs.conv_0.conv.bias\", \"module.down_4.convs.conv_0.adn.N.weight\", \"module.down_4.convs.conv_0.adn.N.bias\", \"module.down_4.convs.conv_1.conv.weight\", \"module.down_4.convs.conv_1.conv.bias\", \"module.down_4.convs.conv_1.adn.N.weight\", \"module.down_4.convs.conv_1.adn.N.bias\", \"module.upcat_4.upsample.deconv.weight\", \"module.upcat_4.upsample.deconv.bias\", \"module.upcat_4.convs.conv_0.conv.weight\", \"module.upcat_4.convs.conv_0.conv.bias\", \"module.upcat_4.convs.conv_0.adn.N.weight\", \"module.upcat_4.convs.conv_0.adn.N.bias\", \"module.upcat_4.convs.conv_1.conv.weight\", \"module.upcat_4.convs.conv_1.conv.bias\", \"module.upcat_4.convs.conv_1.adn.N.weight\", \"module.upcat_4.convs.conv_1.adn.N.bias\", \"module.upcat_3.upsample.deconv.weight\", \"module.upcat_3.upsample.deconv.bias\", \"module.upcat_3.convs.conv_0.conv.weight\", \"module.upcat_3.convs.conv_0.conv.bias\", \"module.upcat_3.convs.conv_0.adn.N.weight\", \"module.upcat_3.convs.conv_0.adn.N.bias\", \"module.upcat_3.convs.conv_1.conv.weight\", \"module.upcat_3.convs.conv_1.conv.bias\", \"module.upcat_3.convs.conv_1.adn.N.weight\", \"module.upcat_3.convs.conv_1.adn.N.bias\", \"module.upcat_2.upsample.deconv.weight\", \"module.upcat_2.upsample.deconv.bias\", \"module.upcat_2.convs.conv_0.conv.weight\", \"module.upcat_2.convs.conv_0.conv.bias\", \"module.upcat_2.convs.conv_0.adn.N.weight\", \"module.upcat_2.convs.conv_0.adn.N.bias\", \"module.upcat_2.convs.conv_1.conv.weight\", \"module.upcat_2.convs.conv_1.conv.bias\", \"module.upcat_2.convs.conv_1.adn.N.weight\", \"module.upcat_2.convs.conv_1.adn.N.bias\", \"module.upcat_1.upsample.deconv.weight\", \"module.upcat_1.upsample.deconv.bias\", \"module.upcat_1.convs.conv_0.conv.weight\", \"module.upcat_1.convs.conv_0.conv.bias\", \"module.upcat_1.convs.conv_0.adn.N.weight\", \"module.upcat_1.convs.conv_0.adn.N.bias\", \"module.upcat_1.convs.conv_1.conv.weight\", \"module.upcat_1.convs.conv_1.conv.bias\", \"module.upcat_1.convs.conv_1.adn.N.weight\", \"module.upcat_1.convs.conv_1.adn.N.bias\", \"module.final_conv.weight\", \"module.final_conv.bias\". ", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/Users/marcelrosier/Projects/helmholtz/AURORA/tester.ipynb Cell 2\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> 1\u001b[0m c2 \u001b[39m=\u001b[39m infer(t1c_file\u001b[39m=\u001b[39;49mt1c_file,\n\u001b[1;32m 2\u001b[0m segmentation_file\u001b[39m=\u001b[39;49msegmentation_file,\n\u001b[1;32m 3\u001b[0m tta\u001b[39m=\u001b[39;49mtta,\n\u001b[1;32m 4\u001b[0m verbosity\u001b[39m=\u001b[39;49mverbosity\n\u001b[1;32m 5\u001b[0m )\n", + "File \u001b[0;32m~/Projects/helmholtz/AURORA/brainles_aurora/core.py:78\u001b[0m, in \u001b[0;36minfer\u001b[0;34m(segmentation_file, t1_file, t1c_file, t2_file, fla_file, whole_network_outputs_file, metastasis_network_outputs_file, cuda_devices, tta, sliding_window_batch_size, workers, threshold, sliding_window_overlap, crop_size, model_selection, verbosity)\u001b[0m\n\u001b[1;32m 69\u001b[0m data_loader \u001b[39m=\u001b[39m _get_dloader(\n\u001b[1;32m 70\u001b[0m mode\u001b[39m=\u001b[39mmode,\n\u001b[1;32m 71\u001b[0m t1_file\u001b[39m=\u001b[39mt1_file,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 75\u001b[0m workers\u001b[39m=\u001b[39mworkers,\n\u001b[1;32m 76\u001b[0m )\n\u001b[1;32m 77\u001b[0m \u001b[39m# load model\u001b[39;00m\n\u001b[0;32m---> 78\u001b[0m model \u001b[39m=\u001b[39m _get_model(\n\u001b[1;32m 79\u001b[0m mode\u001b[39m=\u001b[39;49mmode,\n\u001b[1;32m 80\u001b[0m model_selection\u001b[39m=\u001b[39;49mmodel_selection,\n\u001b[1;32m 81\u001b[0m device\u001b[39m=\u001b[39;49mdevice\n\u001b[1;32m 82\u001b[0m )\n\u001b[1;32m 83\u001b[0m \u001b[39mreturn\u001b[39;00m model\n\u001b[1;32m 84\u001b[0m \u001b[39m# create inferrer\u001b[39;00m\n", + "File \u001b[0;32m~/Projects/helmholtz/AURORA/brainles_aurora/core.py:290\u001b[0m, in \u001b[0;36m_get_model\u001b[0;34m(mode, model_selection, device)\u001b[0m\n\u001b[1;32m 286\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mexists(weights):\n\u001b[1;32m 287\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mNotImplementedError\u001b[39;00m(\n\u001b[1;32m 288\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mNo weights found for model \u001b[39m\u001b[39m{\u001b[39;00mmode\u001b[39m}\u001b[39;00m\u001b[39m and selection \u001b[39m\u001b[39m{\u001b[39;00mmodel_selection\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[0;32m--> 290\u001b[0m model\u001b[39m.\u001b[39;49mload_state_dict(checkpoint[\u001b[39m\"\u001b[39;49m\u001b[39mmodel_state\u001b[39;49m\u001b[39m\"\u001b[39;49m])\n\u001b[1;32m 292\u001b[0m \u001b[39mreturn\u001b[39;00m model\n", + "File \u001b[0;32m~/opt/anaconda3/envs/brainles/lib/python3.10/site-packages/torch/nn/modules/module.py:2152\u001b[0m, in \u001b[0;36mModule.load_state_dict\u001b[0;34m(self, state_dict, strict, assign)\u001b[0m\n\u001b[1;32m 2147\u001b[0m error_msgs\u001b[39m.\u001b[39minsert(\n\u001b[1;32m 2148\u001b[0m \u001b[39m0\u001b[39m, \u001b[39m'\u001b[39m\u001b[39mMissing key(s) in state_dict: \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m. \u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 2149\u001b[0m \u001b[39m'\u001b[39m\u001b[39m, \u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mjoin(\u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00mk\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m'\u001b[39m \u001b[39mfor\u001b[39;00m k \u001b[39min\u001b[39;00m missing_keys)))\n\u001b[1;32m 2151\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mlen\u001b[39m(error_msgs) \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[0;32m-> 2152\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m'\u001b[39m\u001b[39mError(s) in loading state_dict for \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m:\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m{}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 2153\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__class__\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mjoin(error_msgs)))\n\u001b[1;32m 2154\u001b[0m \u001b[39mreturn\u001b[39;00m _IncompatibleKeys(missing_keys, unexpected_keys)\n", + "\u001b[0;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for BasicUNet:\n\tMissing key(s) in state_dict: \"conv_0.conv_0.conv.weight\", \"conv_0.conv_0.conv.bias\", \"conv_0.conv_0.adn.N.weight\", \"conv_0.conv_0.adn.N.bias\", \"conv_0.conv_1.conv.weight\", \"conv_0.conv_1.conv.bias\", \"conv_0.conv_1.adn.N.weight\", \"conv_0.conv_1.adn.N.bias\", \"down_1.convs.conv_0.conv.weight\", \"down_1.convs.conv_0.conv.bias\", \"down_1.convs.conv_0.adn.N.weight\", \"down_1.convs.conv_0.adn.N.bias\", \"down_1.convs.conv_1.conv.weight\", \"down_1.convs.conv_1.conv.bias\", \"down_1.convs.conv_1.adn.N.weight\", \"down_1.convs.conv_1.adn.N.bias\", \"down_2.convs.conv_0.conv.weight\", \"down_2.convs.conv_0.conv.bias\", \"down_2.convs.conv_0.adn.N.weight\", \"down_2.convs.conv_0.adn.N.bias\", \"down_2.convs.conv_1.conv.weight\", \"down_2.convs.conv_1.conv.bias\", \"down_2.convs.conv_1.adn.N.weight\", \"down_2.convs.conv_1.adn.N.bias\", \"down_3.convs.conv_0.conv.weight\", \"down_3.convs.conv_0.conv.bias\", \"down_3.convs.conv_0.adn.N.weight\", \"down_3.convs.conv_0.adn.N.bias\", \"down_3.convs.conv_1.conv.weight\", \"down_3.convs.conv_1.conv.bias\", \"down_3.convs.conv_1.adn.N.weight\", \"down_3.convs.conv_1.adn.N.bias\", \"down_4.convs.conv_0.conv.weight\", \"down_4.convs.conv_0.conv.bias\", \"down_4.convs.conv_0.adn.N.weight\", \"down_4.convs.conv_0.adn.N.bias\", \"down_4.convs.conv_1.conv.weight\", \"down_4.convs.conv_1.conv.bias\", \"down_4.convs.conv_1.adn.N.weight\", \"down_4.convs.conv_1.adn.N.bias\", \"upcat_4.upsample.deconv.weight\", \"upcat_4.upsample.deconv.bias\", \"upcat_4.convs.conv_0.conv.weight\", \"upcat_4.convs.conv_0.conv.bias\", \"upcat_4.convs.conv_0.adn.N.weight\", \"upcat_4.convs.conv_0.adn.N.bias\", \"upcat_4.convs.conv_1.conv.weight\", \"upcat_4.convs.conv_1.conv.bias\", \"upcat_4.convs.conv_1.adn.N.weight\", \"upcat_4.convs.conv_1.adn.N.bias\", \"upcat_3.upsample.deconv.weight\", \"upcat_3.upsample.deconv.bias\", \"upcat_3.convs.conv_0.conv.weight\", \"upcat_3.convs.conv_0.conv.bias\", \"upcat_3.convs.conv_0.adn.N.weight\", \"upcat_3.convs.conv_0.adn.N.bias\", \"upcat_3.convs.conv_1.conv.weight\", \"upcat_3.convs.conv_1.conv.bias\", \"upcat_3.convs.conv_1.adn.N.weight\", \"upcat_3.convs.conv_1.adn.N.bias\", \"upcat_2.upsample.deconv.weight\", \"upcat_2.upsample.deconv.bias\", \"upcat_2.convs.conv_0.conv.weight\", \"upcat_2.convs.conv_0.conv.bias\", \"upcat_2.convs.conv_0.adn.N.weight\", \"upcat_2.convs.conv_0.adn.N.bias\", \"upcat_2.convs.conv_1.conv.weight\", \"upcat_2.convs.conv_1.conv.bias\", \"upcat_2.convs.conv_1.adn.N.weight\", \"upcat_2.convs.conv_1.adn.N.bias\", \"upcat_1.upsample.deconv.weight\", \"upcat_1.upsample.deconv.bias\", \"upcat_1.convs.conv_0.conv.weight\", \"upcat_1.convs.conv_0.conv.bias\", \"upcat_1.convs.conv_0.adn.N.weight\", \"upcat_1.convs.conv_0.adn.N.bias\", \"upcat_1.convs.conv_1.conv.weight\", \"upcat_1.convs.conv_1.conv.bias\", \"upcat_1.convs.conv_1.adn.N.weight\", \"upcat_1.convs.conv_1.adn.N.bias\", \"final_conv.weight\", \"final_conv.bias\". \n\tUnexpected key(s) in state_dict: \"module.conv_0.conv_0.conv.weight\", \"module.conv_0.conv_0.conv.bias\", \"module.conv_0.conv_0.adn.N.weight\", \"module.conv_0.conv_0.adn.N.bias\", \"module.conv_0.conv_1.conv.weight\", \"module.conv_0.conv_1.conv.bias\", \"module.conv_0.conv_1.adn.N.weight\", \"module.conv_0.conv_1.adn.N.bias\", \"module.down_1.convs.conv_0.conv.weight\", \"module.down_1.convs.conv_0.conv.bias\", \"module.down_1.convs.conv_0.adn.N.weight\", \"module.down_1.convs.conv_0.adn.N.bias\", \"module.down_1.convs.conv_1.conv.weight\", \"module.down_1.convs.conv_1.conv.bias\", \"module.down_1.convs.conv_1.adn.N.weight\", \"module.down_1.convs.conv_1.adn.N.bias\", \"module.down_2.convs.conv_0.conv.weight\", \"module.down_2.convs.conv_0.conv.bias\", \"module.down_2.convs.conv_0.adn.N.weight\", \"module.down_2.convs.conv_0.adn.N.bias\", \"module.down_2.convs.conv_1.conv.weight\", \"module.down_2.convs.conv_1.conv.bias\", \"module.down_2.convs.conv_1.adn.N.weight\", \"module.down_2.convs.conv_1.adn.N.bias\", \"module.down_3.convs.conv_0.conv.weight\", \"module.down_3.convs.conv_0.conv.bias\", \"module.down_3.convs.conv_0.adn.N.weight\", \"module.down_3.convs.conv_0.adn.N.bias\", \"module.down_3.convs.conv_1.conv.weight\", \"module.down_3.convs.conv_1.conv.bias\", \"module.down_3.convs.conv_1.adn.N.weight\", \"module.down_3.convs.conv_1.adn.N.bias\", \"module.down_4.convs.conv_0.conv.weight\", \"module.down_4.convs.conv_0.conv.bias\", \"module.down_4.convs.conv_0.adn.N.weight\", \"module.down_4.convs.conv_0.adn.N.bias\", \"module.down_4.convs.conv_1.conv.weight\", \"module.down_4.convs.conv_1.conv.bias\", \"module.down_4.convs.conv_1.adn.N.weight\", \"module.down_4.convs.conv_1.adn.N.bias\", \"module.upcat_4.upsample.deconv.weight\", \"module.upcat_4.upsample.deconv.bias\", \"module.upcat_4.convs.conv_0.conv.weight\", \"module.upcat_4.convs.conv_0.conv.bias\", \"module.upcat_4.convs.conv_0.adn.N.weight\", \"module.upcat_4.convs.conv_0.adn.N.bias\", \"module.upcat_4.convs.conv_1.conv.weight\", \"module.upcat_4.convs.conv_1.conv.bias\", \"module.upcat_4.convs.conv_1.adn.N.weight\", \"module.upcat_4.convs.conv_1.adn.N.bias\", \"module.upcat_3.upsample.deconv.weight\", \"module.upcat_3.upsample.deconv.bias\", \"module.upcat_3.convs.conv_0.conv.weight\", \"module.upcat_3.convs.conv_0.conv.bias\", \"module.upcat_3.convs.conv_0.adn.N.weight\", \"module.upcat_3.convs.conv_0.adn.N.bias\", \"module.upcat_3.convs.conv_1.conv.weight\", \"module.upcat_3.convs.conv_1.conv.bias\", \"module.upcat_3.convs.conv_1.adn.N.weight\", \"module.upcat_3.convs.conv_1.adn.N.bias\", \"module.upcat_2.upsample.deconv.weight\", \"module.upcat_2.upsample.deconv.bias\", \"module.upcat_2.convs.conv_0.conv.weight\", \"module.upcat_2.convs.conv_0.conv.bias\", \"module.upcat_2.convs.conv_0.adn.N.weight\", \"module.upcat_2.convs.conv_0.adn.N.bias\", \"module.upcat_2.convs.conv_1.conv.weight\", \"module.upcat_2.convs.conv_1.conv.bias\", \"module.upcat_2.convs.conv_1.adn.N.weight\", \"module.upcat_2.convs.conv_1.adn.N.bias\", \"module.upcat_1.upsample.deconv.weight\", \"module.upcat_1.upsample.deconv.bias\", \"module.upcat_1.convs.conv_0.conv.weight\", \"module.upcat_1.convs.conv_0.conv.bias\", \"module.upcat_1.convs.conv_0.adn.N.weight\", \"module.upcat_1.convs.conv_0.adn.N.bias\", \"module.upcat_1.convs.conv_1.conv.weight\", \"module.upcat_1.convs.conv_1.conv.bias\", \"module.upcat_1.convs.conv_1.adn.N.weight\", \"module.upcat_1.convs.conv_1.adn.N.bias\", \"module.final_conv.weight\", \"module.final_conv.bias\". " + ] + } + ], + "source": [ + "c2 = infer(\n", + " t1c_file=t1c_file, segmentation_file=segmentation_file, tta=tta, verbosity=verbosity\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'c2' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/Users/marcelrosier/Projects/helmholtz/AURORA/tester.ipynb Cell 3\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> 1\u001b[0m c[\u001b[39m'\u001b[39m\u001b[39mmodel_state\u001b[39m\u001b[39m'\u001b[39m]\u001b[39m.\u001b[39mkeys() \u001b[39m==\u001b[39m c2[\u001b[39m'\u001b[39m\u001b[39mmodel_state\u001b[39m\u001b[39m'\u001b[39m]\u001b[39m.\u001b[39mkeys()\n", + "\u001b[0;31mNameError\u001b[0m: name 'c2' is not defined" + ] + } + ], + "source": [ + "c[\"model_state\"].keys() == c2[\"model_state\"].keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "test = [\n", + " \"module.conv_0.conv_0.conv.weight\",\n", + " \"module.conv_0.conv_0.conv.bias\",\n", + " \"module.conv_0.conv_0.adn.N.weight\",\n", + " \"module.conv_0.conv_0.adn.N.bias\",\n", + " \"module.conv_0.conv_1.conv.weight\",\n", + " \"module.conv_0.conv_1.conv.bias\",\n", + " \"module.conv_0.conv_1.adn.N.weight\",\n", + " \"module.conv_0.conv_1.adn.N.bias\",\n", + " \"module.down_1.convs.conv_0.conv.weight\",\n", + " \"module.down_1.convs.conv_0.conv.bias\",\n", + " \"module.down_1.convs.conv_0.adn.N.weight\",\n", + " \"module.down_1.convs.conv_0.adn.N.bias\",\n", + " \"module.down_1.convs.conv_1.conv.weight\",\n", + " \"module.down_1.convs.conv_1.conv.bias\",\n", + " \"module.down_1.convs.conv_1.adn.N.weight\",\n", + " \"module.down_1.convs.conv_1.adn.N.bias\",\n", + " \"module.down_2.convs.conv_0.conv.weight\",\n", + " \"module.down_2.convs.conv_0.conv.bias\",\n", + " \"module.down_2.convs.conv_0.adn.N.weight\",\n", + " \"module.down_2.convs.conv_0.adn.N.bias\",\n", + " \"module.down_2.convs.conv_1.conv.weight\",\n", + " \"module.down_2.convs.conv_1.conv.bias\",\n", + " \"module.down_2.convs.conv_1.adn.N.weight\",\n", + " \"module.down_2.convs.conv_1.adn.N.bias\",\n", + " \"module.down_3.convs.conv_0.conv.weight\",\n", + " \"module.down_3.convs.conv_0.conv.bias\",\n", + " \"module.down_3.convs.conv_0.adn.N.weight\",\n", + " \"module.down_3.convs.conv_0.adn.N.bias\",\n", + " \"module.down_3.convs.conv_1.conv.weight\",\n", + " \"module.down_3.convs.conv_1.conv.bias\",\n", + " \"module.down_3.convs.conv_1.adn.N.weight\",\n", + " \"module.down_3.convs.conv_1.adn.N.bias\",\n", + " \"module.down_4.convs.conv_0.conv.weight\",\n", + " \"module.down_4.convs.conv_0.conv.bias\",\n", + " \"module.down_4.convs.conv_0.adn.N.weight\",\n", + " \"module.down_4.convs.conv_0.adn.N.bias\",\n", + " \"module.down_4.convs.conv_1.conv.weight\",\n", + " \"module.down_4.convs.conv_1.conv.bias\",\n", + " \"module.down_4.convs.conv_1.adn.N.weight\",\n", + " \"module.down_4.convs.conv_1.adn.N.bias\",\n", + " \"module.upcat_4.upsample.deconv.weight\",\n", + " \"module.upcat_4.upsample.deconv.bias\",\n", + " \"module.upcat_4.convs.conv_0.conv.weight\",\n", + " \"module.upcat_4.convs.conv_0.conv.bias\",\n", + " \"module.upcat_4.convs.conv_0.adn.N.weight\",\n", + " \"module.upcat_4.convs.conv_0.adn.N.bias\",\n", + " \"module.upcat_4.convs.conv_1.conv.weight\",\n", + " \"module.upcat_4.convs.conv_1.conv.bias\",\n", + " \"module.upcat_4.convs.conv_1.adn.N.weight\",\n", + " \"module.upcat_4.convs.conv_1.adn.N.bias\",\n", + " \"module.upcat_3.upsample.deconv.weight\",\n", + " \"module.upcat_3.upsample.deconv.bias\",\n", + " \"module.upcat_3.convs.conv_0.conv.weight\",\n", + " \"module.upcat_3.convs.conv_0.conv.bias\",\n", + " \"module.upcat_3.convs.conv_0.adn.N.weight\",\n", + " \"module.upcat_3.convs.conv_0.adn.N.bias\",\n", + " \"module.upcat_3.convs.conv_1.conv.weight\",\n", + " \"module.upcat_3.convs.conv_1.conv.bias\",\n", + " \"module.upcat_3.convs.conv_1.adn.N.weight\",\n", + " \"module.upcat_3.convs.conv_1.adn.N.bias\",\n", + " \"module.upcat_2.upsample.deconv.weight\",\n", + " \"module.upcat_2.upsample.deconv.bias\",\n", + " \"module.upcat_2.convs.conv_0.conv.weight\",\n", + " \"module.upcat_2.convs.conv_0.conv.bias\",\n", + " \"module.upcat_2.convs.conv_0.adn.N.weight\",\n", + " \"module.upcat_2.convs.conv_0.adn.N.bias\",\n", + " \"module.upcat_2.convs.conv_1.conv.weight\",\n", + " \"module.upcat_2.convs.conv_1.conv.bias\",\n", + " \"module.upcat_2.convs.conv_1.adn.N.weight\",\n", + " \"module.upcat_2.convs.conv_1.adn.N.bias\",\n", + " \"module.upcat_1.upsample.deconv.weight\",\n", + " \"module.upcat_1.upsample.deconv.bias\",\n", + " \"module.upcat_1.convs.conv_0.conv.weight\",\n", + " \"module.upcat_1.convs.conv_0.conv.bias\",\n", + " \"module.upcat_1.convs.conv_0.adn.N.weight\",\n", + " \"module.upcat_1.convs.conv_0.adn.N.bias\",\n", + " \"module.upcat_1.convs.conv_1.conv.weight\",\n", + " \"module.upcat_1.convs.conv_1.conv.bias\",\n", + " \"module.upcat_1.convs.conv_1.adn.N.weight\",\n", + " \"module.upcat_1.convs.conv_1.adn.N.bias\",\n", + " \"module.final_conv.weight\",\n", + " \"module.final_conv.bias\",\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "for k in c[\"model_state\"].keys():\n", + " if not k in test:\n", + " print(k)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "helm", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_inferer.py b/tests/test_inferer.py new file mode 100644 index 0000000..6c7abae --- /dev/null +++ b/tests/test_inferer.py @@ -0,0 +1,139 @@ +from pathlib import Path +from unittest.mock import patch + +import nibabel as nib +import numpy as np +import pytest +import torch + +from brainles_aurora.inferer.constants import InferenceMode +from brainles_aurora.inferer.dataclasses import AuroraInfererConfig +from brainles_aurora.inferer.inferer import AuroraInferer, AuroraGPUInferer + + +class TestAuroraInferer: + @pytest.fixture + def t1_path(self): + return "example_data/BraTS-MET-00110-000-t1n.nii.gz" + + @pytest.fixture + def t1c_path(self): + return "example_data/BraTS-MET-00110-000-t1c.nii.gz" + + @pytest.fixture + def t2_path(self): + return "example_data/BraTS-MET-00110-000-t2w.nii.gz" + + @pytest.fixture + def fla_path(self): + return "example_data/BraTS-MET-00110-000-t2f.nii.gz" + + @pytest.fixture + def mock_config(self, t1_path, t1c_path, t2_path, fla_path): + return AuroraInfererConfig(t1=t1_path, t1c=t1c_path, t2=t2_path, fla=fla_path) + + @pytest.fixture + def mock_inferer(self, mock_config): + return AuroraInferer(config=mock_config) + + @pytest.fixture + def load_np_from_nifti(self): + def _load_np_from_nifti(path): + return nib.load(path).get_fdata() + + return _load_np_from_nifti + + def test_validate_images(self, mock_config): + inferer = AuroraInferer(config=mock_config) + images = inferer._validate_images() + assert len(images) == 4 + assert all(isinstance(img, Path) for img in images) + + def test_validate_images_file_not_found(self, mock_config): + mock_config.t1 = "invalid_path.nii.gz" + with pytest.raises(FileNotFoundError): + _ = AuroraInferer(config=mock_config) + # called internally in __init__ + # inferer._validate_images() + + def test_validate_images_different_types(self, mock_config, load_np_from_nifti): + mock_config.t1 = load_np_from_nifti(mock_config.t1) + with pytest.raises(AssertionError): + _ = AuroraInferer(config=mock_config) + # called internally in __init__ + # inferer._validate_images() + + def test_validate_images_no_inputs(self, mock_config, load_np_from_nifti): + mock_config.t1 = None + mock_config.t1c = None + mock_config.t2 = None + mock_config.fla = None + with pytest.raises(AssertionError): + _ = AuroraInferer(config=mock_config) + # called internally in __init__ + # inferer._validate_images() + + def test_determine_inference_mode(self, mock_config): + inferer = AuroraInferer(config=mock_config) + mode = inferer._determine_inference_mode() + assert isinstance(mode, InferenceMode) + + def test_determine_inference_mode_not_implemented(self, mock_config): + mock_validated_images = [ + None, + None, + None, + None, + ] # set all to None to raise NotImplementedError + with pytest.raises(NotImplementedError), patch( + "brainles_aurora.inferer.inferer.AuroraInferer._validate_images", + return_value=mock_validated_images, + ): + inferer = AuroraInferer(config=mock_config) + # called internally in __init__ + # inferer._determine_inference_mode() + + def test_get_data_loader(self, mock_config): + inferer = AuroraInferer(config=mock_config) + data_loader = inferer._get_data_loader() + assert isinstance(data_loader, torch.utils.data.DataLoader) + + def test_get_model(self, mock_config): + inferer = AuroraInferer(config=mock_config) + model = inferer._get_model() + assert isinstance(model, torch.nn.Module) + + def test_setup_logger(self, mock_config): + inferer = AuroraInferer(config=mock_config) + assert inferer.log_path is not None + assert inferer.output_folder.exists() + assert inferer.output_folder.is_dir() + + def test_infer(self, mock_config): + inferer = AuroraInferer(config=mock_config) + with patch.object(inferer, "_sliding_window_inference", return_value=None): + inferer.infer() + + def test_configure_device(self, mock_config): + inferer = AuroraInferer(config=mock_config) + device = inferer._configure_device() + assert device == torch.device("cpu") + + @pytest.mark.skipif( + not torch.cuda.is_available(), + reason="Skipping GPU device test since cuda is not available", + ) + def test_configure_device_gpu(self, mock_config): + inferer = AuroraGPUInferer(config=mock_config) + device = inferer._configure_device() + assert device == torch.device("cuda") + + def test_get_model(self, mock_config): + inferer = AuroraInferer(config=mock_config) + model = inferer._get_model() + assert isinstance(model, torch.nn.Module) + + def test_get_data_loader(self, mock_config): + inferer = AuroraInferer(config=mock_config) + data_loader = inferer._get_data_loader() + assert isinstance(data_loader, torch.utils.data.DataLoader)