Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 18 additions & 22 deletions monai/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ def _apply_transform(
Otherwise `parameters` is considered as single argument to `transform`.

Args:
transform (Callable[..., ReturnType]): a callable to be used to transform `data`.
parameters (Any): parameters for the `transform`.
unpack_parameters (bool, optional): whether to unpack parameters for `transform`. Defaults to False.
transform: a callable to be used to transform `data`.
parameters: parameters for the `transform`.
unpack_parameters: whether to unpack parameters for `transform`. Defaults to False.

Returns:
ReturnType: The return type of `transform`.
Expand All @@ -64,11 +64,11 @@ def apply_transform(
otherwise transform will be applied once with `data` as the argument.

Args:
transform (Callable[..., ReturnType]): a callable to be used to transform `data`.
data (Any): an object to be transformed.
map_items (bool, optional): whether to apply transform to each item in `data`,
transform: a callable to be used to transform `data`.
data: an object to be transformed.
map_items: whether to apply transform to each item in `data`,
if `data` is a list or tuple. Defaults to True.
unpack_items (bool, optional): [description]. Defaults to False.
unpack_items: whether to unpack parameters using `*`. Defaults to False.

Raises:
Exception: When ``transform`` raises an exception.
Expand Down Expand Up @@ -216,17 +216,15 @@ def __call__(self, data: Any):
return an updated version of ``data``.
To simplify the input validations, most of the transforms assume that

- ``data`` is a Numpy ndarray, PyTorch Tensor or string
- ``data`` is a Numpy ndarray, PyTorch Tensor or string,
- the data shape can be:

#. string data without shape, `LoadImage` transform expects file paths
#. most of the pre-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``,
except that `AddChannel` expects (spatial_dim_1[, spatial_dim_2, ...]) and
`AsChannelFirst` expects (spatial_dim_1[, spatial_dim_2, ...], num_channels)
#. most of the post-processing transforms expect
``(batch_size, num_channels, spatial_dim_1[, spatial_dim_2, ...])``
#. string data without shape, `LoadImage` transform expects file paths,
#. most of the pre-/post-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``,
except for example: `AddChannel` expects (spatial_dim_1[, spatial_dim_2, ...]) and
`AsChannelFirst` expects (spatial_dim_1[, spatial_dim_2, ...], num_channels),

- the channel dimension is not omitted even if number of channels is one
- the channel dimension is often not omitted even if number of channels is one.

This method can optionally take additional arguments to help execute transformation operation.

Expand Down Expand Up @@ -323,18 +321,16 @@ def __call__(self, data):

To simplify the input validations, this method assumes:

- ``data`` is a Python dictionary
- ``data`` is a Python dictionary,
- ``data[key]`` is a Numpy ndarray, PyTorch Tensor or string, where ``key`` is an element
of ``self.keys``, the data shape can be:

#. string data without shape, `LoadImaged` transform expects file paths
#. most of the pre-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``,
except that `AddChanneld` expects (spatial_dim_1[, spatial_dim_2, ...]) and
#. string data without shape, `LoadImaged` transform expects file paths,
#. most of the pre-/post-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``,
except for example: `AddChanneld` expects (spatial_dim_1[, spatial_dim_2, ...]) and
`AsChannelFirstd` expects (spatial_dim_1[, spatial_dim_2, ...], num_channels)
#. most of the post-processing transforms expect
``(batch_size, num_channels, spatial_dim_1[, spatial_dim_2, ...])``

- the channel dimension is not omitted even if number of channels is one
- the channel dimension is often not omitted even if number of channels is one.

Raises:
NotImplementedError: When the subclass does not override this method.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_decollate.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def check_match(self, in1, in2):

def check_decollate(self, dataset):
batch_size = 2
num_workers = 2
num_workers = 2 if sys.platform == "linux" else 0

loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_inverse_collation.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def test_collation(self, _, transform, collate_fn, ndim):
modified_transform = Compose([transform, ResizeWithPadOrCropd(KEYS, 100), ToTensord(KEYS)])

# num workers = 0 for mac or gpu transforms
num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2
num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2

dataset = CacheDataset(data, transform=modified_transform, progress=False)
loader = DataLoader(dataset, num_workers, batch_size=self.batch_size, collate_fn=collate_fn)
Expand Down