diff --git a/monai/data/nifti_saver.py b/monai/data/nifti_saver.py index 6575d31251..a6b8a48d26 100644 --- a/monai/data/nifti_saver.py +++ b/monai/data/nifti_saver.py @@ -37,6 +37,7 @@ def __init__( padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, align_corners: bool = False, dtype: Optional[np.dtype] = np.float64, + output_dtype: Optional[np.dtype] = np.float32, ) -> None: """ Args: @@ -57,6 +58,7 @@ def __init__( dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. + output_dtype: data type for saving data. Defaults to ``np.float32``. """ self.output_dir = output_dir self.output_postfix = output_postfix @@ -66,6 +68,7 @@ def __init__( self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) self.align_corners = align_corners self.dtype = dtype + self.output_dtype = output_dtype self._data_index = 0 def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: @@ -118,6 +121,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] padding_mode=self.padding_mode, align_corners=self.align_corners, dtype=self.dtype, + output_dtype=self.output_dtype, ) def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py index 444768d555..4da2c4394f 100644 --- a/monai/handlers/segmentation_saver.py +++ b/monai/handlers/segmentation_saver.py @@ -38,7 +38,8 @@ def __init__( mode: Union[GridSampleMode, InterpolateMode, str] = "nearest", padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, scale: Optional[int] = None, - dtype: Optional[np.dtype] = None, + dtype: Optional[np.dtype] = np.float64, + output_dtype: Optional[np.dtype] = np.float32, batch_transform: Callable = lambda x: x, output_transform: Callable = lambda x: x, name: Optional[str] = None, @@ -69,8 +70,10 @@ def __init__( scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. It's used for PNG format only. - dtype: convert the image data to save to this data type. - If None, keep the original type of data. It's used for Nifti format only. + dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. + If None, use the data type of input data. To be compatible with other modules, + the output data type is always ``np.float32``, it's used for Nifti format only. + output_dtype: data type for saving data. Defaults to ``np.float32``, it's used for Nifti format only. batch_transform: a callable that is used to transform the ignite.engine.batch into expected format to extract the meta_data dictionary. output_transform: a callable that is used to transform the @@ -90,6 +93,7 @@ def __init__( mode=GridSampleMode(mode), padding_mode=padding_mode, dtype=dtype, + output_dtype=output_dtype, ) elif output_ext == ".png": self.saver = PNGSaver(