Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/2231 transforms #2278

Closed
wants to merge 209 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
209 commits
Select commit Hold shift + click to select a range
7d5e639
torch default for Flip
rijobro May 27, 2021
b01d6bc
correct doc
rijobro May 27, 2021
218a93c
return type
rijobro May 27, 2021
7911ee5
typing
rijobro May 27, 2021
1d38104
use types instead of bools
rijobro May 27, 2021
9a6a1ad
Merge remote-tracking branch 'MONAI/dev' into torch_default_flip
rijobro May 27, 2021
6ade9c3
update var name
rijobro May 27, 2021
427ee2f
abstract common functionality
rijobro May 27, 2021
5105508
torch default
rijobro May 28, 2021
ce26166
Merge branch 'torch_default' into feature/2231-transforms
rijobro May 28, 2021
5680d9e
to_numpy, to_tensor, transpose
rijobro May 28, 2021
d921834
add todotransform
rijobro May 28, 2021
dc68510
Merge branch 'dev' into feature/2231-transforms
wyli May 30, 2021
59a7ce7
full ci/cd for feature branches
wyli May 30, 2021
4662522
more transforms
rijobro Jun 1, 2021
257952b
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jun 1, 2021
9a6901c
SpatialCrop, CenterSpatialCrop, AsDiscrete, RepeatChannel, RemoveRepe…
rijobro Jun 1, 2021
494db1d
MaskIntensity
rijobro Jun 1, 2021
ada481f
gibbsnoise, and restore device
rijobro Jun 1, 2021
b9d91ea
Merge branch 'dev' into feature/2231-transforms
wyli Jun 2, 2021
12d8ac7
rotate90, randrotate90
rijobro Jun 2, 2021
5f8c1c7
rand_gibbs_noise
rijobro Jun 2, 2021
0ec1ff1
crop_foreground
rijobro Jun 2, 2021
207d18f
format
rijobro Jun 2, 2021
1b578a2
torch_seed and rician_noise
rijobro Jun 2, 2021
81bd74f
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jun 2, 2021
c0aad87
format
rijobro Jun 2, 2021
9d75e05
seed as int
rijobro Jun 2, 2021
20283d0
inverses
rijobro Jun 2, 2021
754308b
gaussian_smooth and rand_gaussian_smooth
rijobro Jun 2, 2021
963fc48
gibbs noise
rijobro Jun 2, 2021
7cea565
gaussian sharpen
rijobro Jun 2, 2021
e5da98c
DetectEnvelope, test transforms w/ CUDA
rijobro Jun 3, 2021
0ce6d9d
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jun 3, 2021
618c4b3
SavitzkyGolaySmooth
rijobro Jun 3, 2021
53b381a
scale intensity
rijobro Jun 3, 2021
db2ab8e
rand scale intensity
rijobro Jun 3, 2021
0239a06
code format
rijobro Jun 3, 2021
c3880df
test smart cache dataset returns torch
rijobro Jun 3, 2021
6a26889
code format
rijobro Jun 3, 2021
bfe78b3
Merge branch 'dev' into feature/2231-transforms
rijobro Jun 14, 2021
5513e5d
post merge changes
rijobro Jun 14, 2021
03b825b
StdShiftIntensity and RandStdShiftIntensity
rijobro Jun 14, 2021
9739ba5
TEST_NDARRAYS reduce code duplication
rijobro Jun 14, 2021
a270fba
dtype_convert and testing
rijobro Jun 14, 2021
c95f059
spatial pad, border pad, divisible pad
rijobro Jun 14, 2021
0746282
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jun 14, 2021
12d514d
update convert dtype
rijobro Jun 14, 2021
bc78766
update convert_dtype
rijobro Jun 14, 2021
5208e8d
gibbs_noised returns same type as input
rijobro Jun 14, 2021
80aa083
rand_rician_noised
rijobro Jun 14, 2021
cfb1ff0
nifti writer works with torch.tensor
rijobro Jun 15, 2021
db5a094
test_image_dataset check np or torch dtype
rijobro Jun 15, 2021
2d9258b
guassian smooth needs float type
rijobro Jun 15, 2021
8764a7b
fix tests
rijobro Jun 15, 2021
50244c4
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jun 15, 2021
2817559
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jun 16, 2021
ab9457e
fix gaussiansharpen
rijobro Jun 16, 2021
b2699fc
LoadImaged allows torch.tensor
rijobro Jun 16, 2021
e366205
allow conversion from ITK
rijobro Jun 16, 2021
02a87b7
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jun 28, 2021
8c622d2
post-merge updates
rijobro Jun 28, 2021
8908a99
fix to_tensor 0 dim
rijobro Jun 28, 2021
4473ee5
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jul 2, 2021
5b382e3
mypy
rijobro Jul 2, 2021
0de7329
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jul 2, 2021
b9928ad
Merge branch 'dev' into feature/2231-transforms
rijobro Jul 5, 2021
e2926b8
add temp test script
rijobro Jul 5, 2021
100e69e
update
rijobro Jul 5, 2021
7d889d6
as_discrete test numpy and torch
rijobro Jul 5, 2021
8dbe874
padder np_kwargs
rijobro Jul 5, 2021
4e843f5
fix borderpad
rijobro Jul 5, 2021
ee0f3fa
fix test_invertd
rijobro Jul 5, 2021
94dc56b
fix gaussian smooth
rijobro Jul 6, 2021
b8c52dc
use torch.tensor
rijobro Jul 6, 2021
ed8a42f
zoom and randzoom
rijobro Jul 6, 2021
359fca1
code format
rijobro Jul 6, 2021
e50ae4c
Merge branch 'dev' into feature/2231-transforms
rijobro Jul 6, 2021
4bb16ca
fix std shift intensity test
rijobro Jul 6, 2021
86399b3
fix test_rotate90
rijobro Jul 6, 2021
bdc2687
fix rand_gaussian_sharpen
rijobro Jul 6, 2021
fd82b7e
fix test_detect_envelope
rijobro Jul 6, 2021
32cc1f9
fix test_rand_std_shift_intensity
rijobro Jul 6, 2021
a2b9bd8
fix test_scale_intensity
rijobro Jul 6, 2021
af28b99
fix test_savitzky_golay_smooth
rijobro Jul 6, 2021
3c35c9f
fix test_mask_intensity
rijobro Jul 6, 2021
f8f4b23
fix test_rand_scale_intensity
rijobro Jul 6, 2021
52c8ced
fix test_spatial_pad
rijobro Jul 6, 2021
ed8a82b
code format
rijobro Jul 6, 2021
8b53da5
set num workers=0 for mac
rijobro Jul 6, 2021
e3b243d
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jul 6, 2021
c83e80b
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jul 7, 2021
b82c6f1
fix test_divisible_pad
rijobro Jul 7, 2021
4c71aed
fix test_gaussian_sharpen
rijobro Jul 7, 2021
87c6d83
fix test_rand_rotate90
rijobro Jul 7, 2021
c1fe71b
fix test_crop_foreground
rijobro Jul 7, 2021
61d7510
fix gibbs noise
rijobro Jul 7, 2021
f0afcce
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jul 7, 2021
801ff55
improve docstring
rijobro Jul 7, 2021
169f7b8
rotate as torch
rijobro Jul 7, 2021
c2b0775
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jul 12, 2021
c9dd886
post merge fixes
rijobro Jul 12, 2021
f0e396c
fix test_to_numpy
rijobro Jul 12, 2021
82e944d
skip torch tests if no torch.fft
rijobro Jul 12, 2021
279579d
same for noised
rijobro Jul 12, 2021
595087e
fix test_as_channel_firstd
rijobro Jul 12, 2021
a6fe6f1
fix test_gibbs_noise(d)
rijobro Jul 12, 2021
7992f29
channel first and last
rijobro Jul 12, 2021
2cf20db
isort/mypy
rijobro Jul 12, 2021
0700881
require fftshift
rijobro Jul 13, 2021
4d03ff8
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jul 13, 2021
fc4cc7f
num_workers=0 on windows
rijobro Jul 13, 2021
4d7f1e6
improve testcenterspatialcrop
rijobro Jul 13, 2021
0f2a6c1
pickle generator
rijobro Jul 13, 2021
a186710
rician gpu test
rijobro Jul 13, 2021
6843438
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jul 13, 2021
1e8dfaa
code format
rijobro Jul 13, 2021
3181f4c
flake8
rijobro Jul 13, 2021
4822865
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jul 14, 2021
06f7a6e
fix torch seed
rijobro Jul 14, 2021
289a7ce
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jul 14, 2021
f7aeb07
remove duplicate class
rijobro Jul 14, 2021
31885ea
remove duplicate class
rijobro Jul 14, 2021
5071afe
ClassesToIndices, RandCropByLabelClasses
rijobro Jul 14, 2021
756d0c2
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jul 14, 2021
33e0818
post merge fix
rijobro Jul 14, 2021
7b84d43
can't pickle torch.Generator
rijobro Jul 15, 2021
f8c4440
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jul 15, 2021
b9d0a24
support pytorch==1.6
rijobro Jul 15, 2021
c687d5f
spacing/spacingd
rijobro Jul 15, 2021
f75bb51
pytorch <= 1.5 support
rijobro Jul 15, 2021
79157fa
ignore
rijobro Jul 15, 2021
f659b97
fix unit tests
rijobro Jul 16, 2021
be9dab8
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jul 16, 2021
4447c1a
remove torch
rijobro Jul 16, 2021
1ad0e4d
resample, affine, affinegrid, elastic2d, elastic3d
rijobro Jul 16, 2021
7abf9e6
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jul 19, 2021
d71a91d
numpy/pytorch compatibility
rijobro Jul 19, 2021
9816ffe
resize, resized, png_writer
rijobro Jul 19, 2021
8d5446c
orientation
rijobro Jul 19, 2021
a3e235d
TTA
rijobro Jul 19, 2021
5f0d6a5
add_coordinate_channels
rijobro Jul 19, 2021
b1bb9a1
no more todotransform
rijobro Jul 19, 2021
b7afb91
remove pre_conv_data and post_conv_data
rijobro Jul 19, 2021
5a58249
deepgrow update
rijobro Jul 19, 2021
5a14c2d
no TorchOrNumpyTransform
rijobro Jul 19, 2021
2b82e87
AddExtremePointsChannel
rijobro Jul 19, 2021
912b595
probnms
rijobro Jul 19, 2021
199a374
save image
rijobro Jul 19, 2021
731fe9f
roc fix
rijobro Jul 19, 2021
e1c10d1
vote ensemble
rijobro Jul 19, 2021
a9a79b3
threshold intensity
rijobro Jul 19, 2021
5661271
generate_spatial_bounding_box, BoundingRect, BoundingRectd
rijobro Jul 19, 2021
317eaf3
transpose, transposed
rijobro Jul 19, 2021
b9269a6
get_extreme
rijobro Jul 19, 2021
82ecaa0
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jul 19, 2021
1df2dcb
update gibbs
rijobro Jul 20, 2021
e9b4513
as discrete
rijobro Jul 20, 2021
045d418
code format
rijobro Jul 20, 2021
b8bcb94
activations
rijobro Jul 20, 2021
0aeafc9
float -> torch.float32
rijobro Jul 20, 2021
3e58f25
line
rijobro Jul 20, 2021
bc67dbe
gaussian noise(d)
rijobro Jul 20, 2021
c4fa7e8
update rb script
rijobro Jul 20, 2021
38c5d48
labeltocontour
rijobro Jul 20, 2021
da2b3a3
normalize intensity
rijobro Jul 20, 2021
b50fcce
crop foreground
rijobro Jul 20, 2021
73e3e70
LabelToMask[d]
rijobro Jul 20, 2021
e1d81ed
RandCropByPosNegLabel
rijobro Jul 20, 2021
d70f319
center scale crop
rijobro Jul 21, 2021
d1fa1d7
resizewithpadorcrop
rijobro Jul 21, 2021
87264a2
randweightedcrop
rijobro Jul 21, 2021
37da23e
merge dev into Feature/2231 transforms (#2640)
wyli Jul 22, 2021
60f35fc
RandSpatialCrop, RandSpatialCropSamples
rijobro Jul 22, 2021
dcd226b
k space spike
rijobro Jul 22, 2021
5cd4319
rand bias
rijobro Jul 22, 2021
ae673c9
vote ensemble
rijobro Jul 22, 2021
fec7286
Merge branch 'dev' into feature/2231-transforms
rijobro Jul 23, 2021
9264f1b
post merge fixes
rijobro Jul 23, 2021
af4252c
one hot for np as well
rijobro Jul 23, 2021
f3e9b6a
loadsa stuff
rijobro Jul 23, 2021
4d12313
Merge branch 'dev' into feature/2231-transforms
rijobro Jul 28, 2021
57fc771
fg_bg_to_indices
rijobro Jul 28, 2021
b940325
RandCoarseDropout
rijobro Jul 28, 2021
7f6f05f
ConvertToMultiChannelBasedOnBratsClasses
rijobro Jul 28, 2021
f0419d0
PadCollation
rijobro Jul 28, 2021
ed6903a
mypy
rijobro Jul 28, 2021
9a7cf96
fix 1d SpatialCrop
rijobro Jul 28, 2021
9db3256
fix randbiasfield
rijobro Jul 28, 2021
ecbb200
fix KSpaceSpikeNoise
rijobro Jul 28, 2021
c3323d5
fix RandBiasFieldd
rijobro Jul 28, 2021
e645a3f
fix generate_pos_neg_label_crop_centers
rijobro Jul 28, 2021
af49fd7
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jul 28, 2021
e31bfb0
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jul 29, 2021
293f196
fix generate_spatial_bounding_box
rijobro Jul 29, 2021
d03911b
fix resize
rijobro Jul 29, 2021
40fa288
separate one_hot fns for np and torch (allows for torchscript)
rijobro Jul 29, 2021
fd1af06
fix various tests
rijobro Jul 29, 2021
be9ce91
fix one_hot
rijobro Jul 29, 2021
6401533
fix one_hot 2
rijobro Jul 29, 2021
c303989
fix test_invers
rijobro Jul 29, 2021
34776d8
fix pt1.6
rijobro Jul 30, 2021
0ead566
fix
rijobro Jul 30, 2021
6d18280
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jul 30, 2021
446d7ea
pt1.7 fix
rijobro Jul 30, 2021
cb1f047
fix torch1.7
rijobro Jul 30, 2021
8d5af04
Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms
rijobro Jul 30, 2021
f30f72e
post merge fix
rijobro Jul 30, 2021
228e3b4
get number of conversions
rijobro Aug 10, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 41 additions & 22 deletions monai/apps/deepgrow/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from monai.transforms import AsChannelFirstd, Compose, LoadImaged, Orientationd, Spacingd
from monai.utils import GridSampleMode
from monai.utils.misc import convert_data_type


def create_dataset(
Expand Down Expand Up @@ -82,7 +83,7 @@ def create_dataset(
if not len(datalist):
raise ValueError("Input datalist is empty")

transforms = _default_transforms(image_key, label_key, pixdim) if transforms is None else transforms
transforms = transforms or _default_transforms(image_key, label_key, pixdim)
new_datalist = []
for idx in range(len(datalist)):
if limit and idx >= limit:
Expand Down Expand Up @@ -133,25 +134,34 @@ def _default_transforms(image_key, label_key, pixdim):


def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):
if vol_image is not None:
vol_image_np, *_ = convert_data_type(vol_image, np.ndarray)
else:
vol_image_np = vol_image
if vol_label is not None:
vol_label_np, *_ = convert_data_type(vol_label, np.ndarray)
else:
vol_label_np = vol_label

data_list = []

if len(vol_image.shape) == 4:
if len(vol_image_np.shape) == 4:
logging.info(
"4D-Image, pick only first series; Image: {}; Label: {}".format(
vol_image.shape, vol_label.shape if vol_label is not None else None
vol_image_np.shape, vol_label_np.shape if vol_label_np is not None else None
)
)
vol_image = vol_image[0]
vol_image = np.moveaxis(vol_image, -1, 0)
vol_image_np = vol_image_np[0]
vol_image_np = np.moveaxis(vol_image_np, -1, 0)

image_count = 0
label_count = 0
unique_labels_count = 0
for sid in range(vol_image.shape[0]):
image = vol_image[sid, ...]
label = vol_label[sid, ...] if vol_label is not None else None
for sid in range(vol_image_np.shape[0]):
image = vol_image_np[sid, ...]
label = vol_label_np[sid, ...] if vol_label_np is not None else None

if vol_label is not None and np.sum(label) == 0:
if vol_label_np is not None and np.sum(label) == 0:
continue

image_file_prefix = "vol_idx_{:0>4d}_slice_{:0>3d}".format(vol_idx, sid)
Expand All @@ -163,7 +173,7 @@ def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):
image_count += 1

# Test Data
if vol_label is None:
if vol_label_np is None:
data_list.append(
{
"image": image_file.replace(dataset_dir + os.pathsep, "") if relative_path else image_file,
Expand Down Expand Up @@ -200,9 +210,9 @@ def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):
logging.info(
"{} => Image Shape: {} => {}; Label Shape: {} => {}; Unique Labels: {}".format(
vol_idx,
vol_image.shape,
vol_image_np.shape,
image_count,
vol_label.shape if vol_label is not None else None,
vol_label_np.shape if vol_label_np is not None else None,
label_count,
unique_labels_count,
)
Expand All @@ -211,16 +221,25 @@ def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):


def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):
if vol_image is not None:
vol_image_np, *_ = convert_data_type(vol_image, np.ndarray)
else:
vol_image_np = vol_image
if vol_label is not None:
vol_label_np, *_ = convert_data_type(vol_label, np.ndarray)
else:
vol_label_np = vol_label

data_list = []

if len(vol_image.shape) == 4:
if len(vol_image_np.shape) == 4:
logging.info(
"4D-Image, pick only first series; Image: {}; Label: {}".format(
vol_image.shape, vol_label.shape if vol_label is not None else None
vol_image_np.shape, vol_label_np.shape if vol_label_np is not None else None
)
)
vol_image = vol_image[0]
vol_image = np.moveaxis(vol_image, -1, 0)
vol_image_np = vol_image_np[0]
vol_image_np = np.moveaxis(vol_image_np, -1, 0)

image_count = 0
label_count = 0
Expand All @@ -231,19 +250,19 @@ def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):
image_file += ".npy"

os.makedirs(os.path.join(dataset_dir, "images"), exist_ok=True)
np.save(image_file, vol_image)
np.save(image_file, vol_image_np)
image_count += 1

# Test Data
if vol_label is None:
if vol_label_np is None:
data_list.append(
{
"image": image_file.replace(dataset_dir + os.pathsep, "") if relative_path else image_file,
}
)
else:
# For all Labels
unique_labels = np.unique(vol_label.flatten())
unique_labels = np.unique(vol_label_np.flatten())
unique_labels = unique_labels[unique_labels != 0]
unique_labels_count = max(unique_labels_count, len(unique_labels))

Expand All @@ -252,7 +271,7 @@ def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):
label_file = os.path.join(dataset_dir, "labels", label_file_prefix)
label_file += ".npy"

curr_label = (vol_label == idx).astype(np.float32)
curr_label = (vol_label_np == idx).astype(np.float32)
os.makedirs(os.path.join(dataset_dir, "labels"), exist_ok=True)
np.save(label_file, curr_label)

Expand All @@ -271,9 +290,9 @@ def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):
logging.info(
"{} => Image Shape: {} => {}; Label Shape: {} => {}; Unique Labels: {}".format(
vol_idx,
vol_image.shape,
vol_image_np.shape,
image_count,
vol_label.shape if vol_label is not None else None,
vol_label_np.shape if vol_label_np is not None else None,
label_count,
unique_labels_count,
)
Expand Down
6 changes: 3 additions & 3 deletions monai/apps/pathology/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union
from typing import List

import numpy as np
import torch

from monai.transforms.post.array import ProbNMS
from monai.utils import optional_import
from monai.utils.enums import DataObjects

measure, _ = optional_import("skimage.measure")
ndimage, _ = optional_import("scipy.ndimage")
Expand Down Expand Up @@ -67,7 +67,7 @@ class PathologyProbNMS(ProbNMS):

def __call__(
self,
probs_map: Union[np.ndarray, torch.Tensor],
probs_map: DataObjects.Images,
resolution_level: int = 0,
):
"""
Expand Down
7 changes: 4 additions & 3 deletions monai/data/csv_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
import os
import warnings
from collections import OrderedDict
from typing import Dict, Optional, Union
from typing import Dict, Optional

import numpy as np
import torch

from monai.utils import ImageMetaKey as Key
from monai.utils.enums import DataObjects


class CSVSaver:
Expand Down Expand Up @@ -75,7 +76,7 @@ def finalize(self) -> None:
# clear cache content after writing
self.reset_cache()

def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None:
def save(self, data: DataObjects.Images, meta_data: Optional[Dict] = None) -> None:
"""Save data into the cache dictionary. The metadata should have the following key:
- ``'filename_or_obj'`` -- save the data corresponding to file name or object.
If meta_data is None, use the default index from 0 to save data instead.
Expand All @@ -92,7 +93,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
data = data.detach().cpu().numpy()
self._cache_dict[save_key] = np.asarray(data, dtype=float)

def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None:
def save_batch(self, batch_data: DataObjects.Images, meta_data: Optional[Dict] = None) -> None:
"""Save a batch of data into the cache dictionary.

Args:
Expand Down
5 changes: 3 additions & 2 deletions monai/data/nifti_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from monai.data.utils import create_file_basename
from monai.utils import GridSampleMode, GridSamplePadMode
from monai.utils import ImageMetaKey as Key
from monai.utils.enums import DataObjects


class NiftiSaver:
Expand Down Expand Up @@ -104,7 +105,7 @@ def __init__(
self.separate_folder = separate_folder
self.print_log = print_log

def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None:
def save(self, data: DataObjects.Images, meta_data: Optional[Dict] = None) -> None:
"""
Save data into a Nifti file.
The meta_data could optionally have the following keys:
Expand Down Expand Up @@ -175,7 +176,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
if self.print_log:
print(f"file written: {path}.")

def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None:
def save_batch(self, batch_data: DataObjects.Images, meta_data: Optional[Dict] = None) -> None:
"""
Save a batch of data into Nifti format files.

Expand Down
46 changes: 27 additions & 19 deletions monai/data/nifti_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
from monai.data.utils import compute_shape_offset, to_affine_nd
from monai.networks.layers import AffineTransform
from monai.utils import GridSampleMode, GridSamplePadMode, optional_import
from monai.utils.enums import DataObjects
from monai.utils.misc import convert_data_type

nib, _ = optional_import("nibabel")


def write_nifti(
data: np.ndarray,
data: DataObjects.Images,
file_name: str,
affine: Optional[np.ndarray] = None,
target_affine: Optional[np.ndarray] = None,
Expand All @@ -36,7 +38,7 @@ def write_nifti(
output_dtype: DtypeLike = np.float32,
) -> None:
"""
Write numpy data into NIfTI files to disk. This function converts data
Write numpy or torch data into NIfTI files to disk. This function converts data
into the coordinate system defined by `target_affine` when `target_affine`
is specified.

Expand Down Expand Up @@ -96,33 +98,39 @@ def write_nifti(
If None, use the data type of input data.
output_dtype: data type for saving data. Defaults to ``np.float32``.
"""
if not isinstance(data, np.ndarray):
raise AssertionError("input data must be numpy array.")
dtype = dtype or data.dtype
sr = min(data.ndim, 3)
if not isinstance(data, (np.ndarray, torch.Tensor)):
raise AssertionError("input data must be numpy array or torch.Tensor.")
# if torch, convert to numpy
data_np: np.ndarray
data_np, *_ = convert_data_type(data, np.ndarray) # type: ignore
if target_affine is not None:
target_affine, *_ = convert_data_type(target_affine, np.ndarray) # type: ignore

dtype = dtype or data_np.dtype
sr = min(data_np.ndim, 3)
if affine is None:
affine = np.eye(4, dtype=np.float64)
affine = to_affine_nd(sr, affine)
affine = to_affine_nd(sr, affine) # type: ignore

if target_affine is None:
target_affine = affine
target_affine = to_affine_nd(sr, target_affine)
target_affine = to_affine_nd(sr, target_affine) # type: ignore

if np.allclose(affine, target_affine, atol=1e-3):
# no affine changes, save (data, affine)
results_img = nib.Nifti1Image(data.astype(output_dtype), to_affine_nd(3, target_affine))
results_img = nib.Nifti1Image(data_np.astype(output_dtype), to_affine_nd(3, target_affine))
nib.save(results_img, file_name)
return

# resolve orientation
start_ornt = nib.orientations.io_orientation(affine)
target_ornt = nib.orientations.io_orientation(target_affine)
ornt_transform = nib.orientations.ornt_transform(start_ornt, target_ornt)
data_shape = data.shape
data = nib.orientations.apply_orientation(data, ornt_transform)
data_shape = data_np.shape
data_np = nib.orientations.apply_orientation(data_np, ornt_transform)
_affine = affine @ nib.orientations.inv_ornt_aff(ornt_transform, data_shape)
if np.allclose(_affine, target_affine, atol=1e-3) or not resample:
results_img = nib.Nifti1Image(data.astype(output_dtype), to_affine_nd(3, _affine))
results_img = nib.Nifti1Image(data_np.astype(output_dtype), to_affine_nd(3, _affine))
nib.save(results_img, file_name)
return

Expand All @@ -132,13 +140,13 @@ def write_nifti(
)
transform = np.linalg.inv(_affine) @ target_affine
if output_spatial_shape is None:
output_spatial_shape, _ = compute_shape_offset(data.shape, _affine, target_affine)
output_spatial_shape, _ = compute_shape_offset(data_np.shape, _affine, target_affine)
output_spatial_shape_ = list(output_spatial_shape) if output_spatial_shape is not None else []
if data.ndim > 3: # multi channel, resampling each channel
if data_np.ndim > 3: # multi channel, resampling each channel
while len(output_spatial_shape_) < 3:
output_spatial_shape_ = output_spatial_shape_ + [1]
spatial_shape, channel_shape = data.shape[:3], data.shape[3:]
data_np = data.reshape(list(spatial_shape) + [-1])
spatial_shape, channel_shape = data_np.shape[:3], data_np.shape[3:]
data_np = data_np.reshape(list(spatial_shape) + [-1])
data_np = np.moveaxis(data_np, -1, 0) # channel first for pytorch
data_torch = affine_xform(
torch.as_tensor(np.ascontiguousarray(data_np).astype(dtype)).unsqueeze(0),
Expand All @@ -149,12 +157,12 @@ def write_nifti(
data_np = np.moveaxis(data_np, 0, -1) # channel last for nifti
data_np = data_np.reshape(list(data_np.shape[:3]) + list(channel_shape))
else: # single channel image, need to expand to have batch and channel
while len(output_spatial_shape_) < len(data.shape):
while len(output_spatial_shape_) < len(data_np.shape):
output_spatial_shape_ = output_spatial_shape_ + [1]
data_torch = affine_xform(
torch.as_tensor(np.ascontiguousarray(data).astype(dtype)[None, None]),
torch.as_tensor(np.ascontiguousarray(data_np).astype(dtype)[None, None]),
torch.as_tensor(np.ascontiguousarray(transform).astype(dtype)),
spatial_size=output_spatial_shape_[: len(data.shape)],
spatial_size=output_spatial_shape_[: len(data_np.shape)],
)
data_np = data_torch.squeeze(0).squeeze(0).detach().cpu().numpy()

Expand Down
5 changes: 3 additions & 2 deletions monai/data/png_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from monai.data.utils import create_file_basename
from monai.utils import ImageMetaKey as Key
from monai.utils import InterpolateMode, look_up_option
from monai.utils.enums import DataObjects


class PNGSaver:
Expand Down Expand Up @@ -82,7 +83,7 @@ def __init__(

self._data_index = 0

def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None:
def save(self, data: DataObjects.Images, meta_data: Optional[Dict] = None) -> None:
"""
Save data into a png file.
The meta_data could optionally have the following keys:
Expand Down Expand Up @@ -143,7 +144,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
if self.print_log:
print(f"file written: {path}.")

def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None:
def save_batch(self, batch_data: DataObjects.Images, meta_data: Optional[Dict] = None) -> None:
"""Save a batch of data into png format files.

Args:
Expand Down
Loading