Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1378 Add output_dtype to NiftiSaver and SegmentationSaver #1394

Merged
merged 2 commits into from
Dec 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions monai/data/nifti_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
align_corners: bool = False,
dtype: Optional[np.dtype] = np.float64,
output_dtype: Optional[np.dtype] = np.float32,
) -> None:
"""
Args:
Expand All @@ -57,6 +58,7 @@ def __init__(
dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
If None, use the data type of input data. To be compatible with other modules,
the output data type is always ``np.float32``.
output_dtype: data type for saving data. Defaults to ``np.float32``.
"""
self.output_dir = output_dir
self.output_postfix = output_postfix
Expand All @@ -66,6 +68,7 @@ def __init__(
self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode)
self.align_corners = align_corners
self.dtype = dtype
self.output_dtype = output_dtype
self._data_index = 0

def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None:
Expand Down Expand Up @@ -118,6 +121,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
padding_mode=self.padding_mode,
align_corners=self.align_corners,
dtype=self.dtype,
output_dtype=self.output_dtype,
)

def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None:
Expand Down
10 changes: 7 additions & 3 deletions monai/handlers/segmentation_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def __init__(
mode: Union[GridSampleMode, InterpolateMode, str] = "nearest",
padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
scale: Optional[int] = None,
dtype: Optional[np.dtype] = None,
dtype: Optional[np.dtype] = np.float64,
output_dtype: Optional[np.dtype] = np.float32,
batch_transform: Callable = lambda x: x,
output_transform: Callable = lambda x: x,
name: Optional[str] = None,
Expand Down Expand Up @@ -69,8 +70,10 @@ def __init__(
scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling
[0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling.
It's used for PNG format only.
dtype: convert the image data to save to this data type.
If None, keep the original type of data. It's used for Nifti format only.
dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
If None, use the data type of input data. To be compatible with other modules,
the output data type is always ``np.float32``, it's used for Nifti format only.
output_dtype: data type for saving data. Defaults to ``np.float32``, it's used for Nifti format only.
batch_transform: a callable that is used to transform the
ignite.engine.batch into expected format to extract the meta_data dictionary.
output_transform: a callable that is used to transform the
Expand All @@ -90,6 +93,7 @@ def __init__(
mode=GridSampleMode(mode),
padding_mode=padding_mode,
dtype=dtype,
output_dtype=output_dtype,
)
elif output_ext == ".png":
self.saver = PNGSaver(
Expand Down