From 7eed04bc480d3483b5c578a15deae63b95852f3c Mon Sep 17 00:00:00 2001
From: Marcel Rosier
Date: Mon, 29 Jan 2024 14:17:38 +0100
Subject: [PATCH] Remove output_mode parameter (always return numpy data and
 save files if paths provided)

---
 brainles_aurora/inferer/dataclasses.py |  3 ---
 brainles_aurora/inferer/inferer.py     | 36 ++++++++++----------------
 segmentation_test.py                   |  2 --
 3 files changed, 14 insertions(+), 27 deletions(-)

diff --git a/brainles_aurora/inferer/dataclasses.py b/brainles_aurora/inferer/dataclasses.py
index be2f9b3..09626bc 100644
--- a/brainles_aurora/inferer/dataclasses.py
+++ b/brainles_aurora/inferer/dataclasses.py
@@ -11,11 +11,9 @@ class BaseConfig:
     """Base configuration for the Aurora model inferer.
 
     Attributes:
-        output_mode (DataMode, optional): Output mode for the inference results. Defaults to DataMode.NIFTI_FILE.
         log_level (int | str, optional): Logging level. Defaults to logging.INFO.
     """
 
-    output_mode: DataMode = DataMode.NIFTI_FILE
     log_level: int | str = logging.INFO
 
 
@@ -24,7 +22,6 @@ class AuroraInfererConfig(BaseConfig):
     """Configuration for the Aurora model inferer.
 
     Attributes:
-        output_mode (DataMode, optional): Output mode for the inference results. Defaults to DataMode.NIFTI_FILE.
         log_level (int | str, optional): Logging level. Defaults to logging.INFO.
         tta (bool, optional): Whether to apply test-time augmentations. Defaults to True.
         sliding_window_batch_size (int, optional): Batch size for sliding window inference. Defaults to 1.
diff --git a/brainles_aurora/inferer/inferer.py b/brainles_aurora/inferer/inferer.py
index 9291794..d8259a1 100644
--- a/brainles_aurora/inferer/inferer.py
+++ b/brainles_aurora/inferer/inferer.py
@@ -427,11 +427,11 @@ def _post_process(
             Output.METASTASIS_NETWORK: enhancing_out,
         }
 
-    def _sliding_window_inference(self) -> None | Dict[str, np.ndarray]:
+    def _sliding_window_inference(self) -> Dict[str, np.ndarray]:
         """Perform sliding window inference using monai.inferers.SlidingWindowInferer.
 
         Returns:
-            None | Dict[str, np.ndarray]: Post-processed data if output_mode is NUMPY, otherwise the data is saved as a niftis and None is returned.
+            Dict[str, np.ndarray]: Post-processed data.
         """
         inferer = SlidingWindowInferer(
             roi_size=self.config.crop_size,  # = patch_size
@@ -461,15 +461,14 @@ def _sliding_window_inference(self) -> None | Dict[str, np.ndarray]:
             postprocessed_data = self._post_process(
                 onehot_model_outputs_CHWD=outputs,
             )
-            if self.config.output_mode == DataMode.NUMPY:
-                self.log.info(
-                    "Returning post-processed data as Dict of Numpy arrays"
-                )
-                return postprocessed_data
-            else:
+
+            # save data to file if paths are provided
+            if any(self.output_file_mapping.values()):
                 self.log.info("Saving post-processed data as NIFTI files")
                 self._save_as_nifti(postproc_data=postprocessed_data)
-                return None
+
+            self.log.info("Returning post-processed data as Dict of Numpy arrays")
+            return postprocessed_data
 
     def _configure_device(self) -> torch.device:
         """Configure the device for inference.
@@ -511,7 +510,7 @@ def infer(
             log_file (str | Path | None, optional): _description_. Defaults to None.
 
         Returns:
-            Dict[str, np.ndarray] | None: Post-processed data if output_mode is NUMPY, otherwise the data is saved as a niftis and None is returned.
+            Dict[str, np.ndarray]: Post-processed data.
""" # setup logger for inference run if log_file: @@ -545,18 +544,11 @@ def infer( self.data_loader = self._get_data_loader() # setup output file paths - if self.config.output_mode == DataMode.NIFTI_FILE: - # TODO add error handling to ensure file extensions present - if not segmentation_file: - default_segmentation_path = os.path.abspath("./segmentation.nii.gz") - self.log.warning( - f"No segmentation file name provided, using default path: {default_segmentation_path}" - ) - self.output_file_mapping = { - Output.SEGMENTATION: segmentation_file or default_segmentation_path, - Output.WHOLE_NETWORK: whole_tumor_unbinarized_floats_file, - Output.METASTASIS_NETWORK: metastasis_unbinarized_floats_file, - } + self.output_file_mapping = { + Output.SEGMENTATION: segmentation_file, + Output.WHOLE_NETWORK: whole_tumor_unbinarized_floats_file, + Output.METASTASIS_NETWORK: metastasis_unbinarized_floats_file, + } ######## self.log.info(f"Running inference on device := {self.device}") diff --git a/segmentation_test.py b/segmentation_test.py index 7214b95..2f1395c 100644 --- a/segmentation_test.py +++ b/segmentation_test.py @@ -76,7 +76,6 @@ def cpu_nifti(): def gpu_np(): config = AuroraInfererConfig( tta=False, - output_mode=DataMode.NUMPY, ) # disable tta for faster inference in this showcase # If you don-t have a GPU that supports CUDA use the CPU version: AuroraInferer(config=config) @@ -95,7 +94,6 @@ def gpu_output_np(): t1c=load_np_from_nifti(t1c), t2=load_np_from_nifti(t2), fla=load_np_from_nifti(fla), - output_mode=DataMode.NUMPY, ) inferer = AuroraGPUInferer( config=config,