Skip to content

Commit

Permalink
5983 update to use np.linalg for the small affine inverse (#5967)
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <wenqil@nvidia.com>


Fixes #5983 
Fixes #5696 

### Description
- torch.inverse was less stable across different base image versions
- there's a bug when loading the module (see pytorch/pytorch#90613);
  this PR therefore also closes #5696

This PR tries to move away from the torch.linalg APIs, in favour of np.linalg,
for small matrix-inverse tasks that do not need to be differentiable.
![Screenshot 2023-02-09 at 23 42
15](https://user-images.githubusercontent.com/831580/217963728-7c96733d-8570-4643-8fe1-2b8867e8e511.png)

![Screenshot 2023-02-13 at 14 47
50](https://user-images.githubusercontent.com/831580/218489820-578ff684-6173-4d46-9603-779b0e2e2fe2.png)


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
  • Loading branch information
wyli authored Feb 13, 2023
1 parent 71abf1b commit 3122e1a
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 327 deletions.
9 changes: 5 additions & 4 deletions monai/networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@
from copy import deepcopy
from typing import Any

import numpy as np
import torch
import torch.nn as nn

from monai.apps.utils import get_logger
from monai.config import PathLike
from monai.utils.misc import ensure_tuple, save_obj, set_determinism
from monai.utils.module import look_up_option, pytorch_after
from monai.utils.type_conversion import convert_to_tensor
from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor

__all__ = [
"one_hot",
Expand Down Expand Up @@ -185,7 +186,7 @@ def predict_segmentation(logits: torch.Tensor, mutually_exclusive: bool = False,

def normalize_transform(
shape,
device: torch.device | None = None,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
align_corners: bool = False,
zero_centered: bool = False,
Expand Down Expand Up @@ -264,8 +265,8 @@ def to_norm_affine(
raise ValueError(f"affine suggests {sr}D, got src={len(src_size)}D, dst={len(dst_size)}D.")

src_xform = normalize_transform(src_size, affine.device, affine.dtype, align_corners, zero_centered)
dst_xform = normalize_transform(dst_size, affine.device, affine.dtype, align_corners, zero_centered)
return src_xform @ affine @ torch.inverse(dst_xform)
dst_xform = normalize_transform(dst_size, "cpu", affine.dtype, align_corners, zero_centered)
return src_xform @ affine @ convert_to_dst_type(np.linalg.inv(dst_xform.numpy()), dst=affine)[0] # monai#5983


def normal_init(
Expand Down
29 changes: 13 additions & 16 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@
fall_back_tuple,
issequenceiterable,
optional_import,
pytorch_after,
)
from monai.utils.deprecate_utils import deprecated_arg
from monai.utils.enums import GridPatchSort, PytorchPadMode, TraceKeys, TransformBackends, WSIPatchKeys
Expand Down Expand Up @@ -272,14 +271,12 @@ def __call__(
)

try:
_s = convert_to_tensor(src_affine_, track_meta=False, device=torch.device("cpu"))
_d = convert_to_tensor(dst_affine, track_meta=False, device=torch.device("cpu"))
xform = (
torch.linalg.solve(_s, _d) if pytorch_after(1, 8, 0) else torch.solve(_d, _s).solution # type: ignore
)
_s = convert_to_numpy(src_affine_)
_d = convert_to_numpy(dst_affine)
xform = np.linalg.solve(_s, _d) # monai#5983
except (np.linalg.LinAlgError, RuntimeError) as e:
raise ValueError("src affine is not invertible.") from e
xform = to_affine_nd(spatial_rank, xform).to(device=img.device, dtype=_dtype)
raise ValueError(f"src affine is not invertible {_s}, {_d}.") from e
xform = convert_to_tensor(to_affine_nd(spatial_rank, xform)).to(device=img.device, dtype=_dtype)
# no resampling if it's identity transform
if allclose(xform, torch.eye(len(xform)), atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size):
return self._post_process(
Expand All @@ -293,12 +290,12 @@ def __call__(
xform_shape = [-1] + in_spatial_size
img = img.reshape(xform_shape) # type: ignore
if isinstance(mode, int):
dst_xform_1 = normalize_transform(spatial_size, xform.device, xform.dtype, True, True)[0] # to (-1, 1)
dst_xform_1 = normalize_transform(spatial_size, "cpu", xform.dtype, True, True)[0].numpy() # to (-1, 1)
if not align_corners:
norm = create_scale(spatial_rank, [(max(d, 2) - 1) / d for d in spatial_size], xform.device, "torch")
dst_xform_1 = norm.to(xform.dtype) @ dst_xform_1 # type: ignore # scaling (num_step - 1) / num_step
dst_xform_d = normalize_transform(spatial_size, xform.device, xform.dtype, align_corners, False)[0]
xform = xform @ torch.inverse(dst_xform_d) @ dst_xform_1
norm = create_scale(spatial_rank, [(max(d, 2) - 1) / d for d in spatial_size])
dst_xform_1 = norm.astype(float) @ dst_xform_1 # type: ignore # scaling (num_step - 1) / num_step
dst_xform_d = normalize_transform(spatial_size, "cpu", xform.dtype, align_corners, False)[0].numpy()
xform @= convert_to_dst_type(np.linalg.solve(dst_xform_d, dst_xform_1), xform)[0]
affine_xform = Affine(
affine=xform, spatial_size=spatial_size, normalized=True, image_only=True, dtype=_dtype # type: ignore
)
Expand Down Expand Up @@ -1084,7 +1081,7 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor:
padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"]
align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"]
dtype = transform[TraceKeys.EXTRA_INFO]["dtype"]
inv_rot_mat = linalg_inv(fwd_rot_mat)
inv_rot_mat = linalg_inv(convert_to_numpy(fwd_rot_mat))

xform = AffineTransform(
normalized=False,
Expand Down Expand Up @@ -2281,7 +2278,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
fwd_affine = transform[TraceKeys.EXTRA_INFO]["affine"]
mode = transform[TraceKeys.EXTRA_INFO]["mode"]
padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"]
inv_affine = linalg_inv(fwd_affine)
inv_affine = linalg_inv(convert_to_numpy(fwd_affine))
inv_affine = convert_to_dst_type(inv_affine, data, dtype=inv_affine.dtype)[0]

affine_grid = AffineGrid(affine=inv_affine)
Expand Down Expand Up @@ -2520,7 +2517,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
fwd_affine = transform[TraceKeys.EXTRA_INFO]["affine"]
mode = transform[TraceKeys.EXTRA_INFO]["mode"]
padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"]
inv_affine = linalg_inv(fwd_affine)
inv_affine = linalg_inv(convert_to_numpy(fwd_affine))
inv_affine = convert_to_dst_type(inv_affine, data, dtype=inv_affine.dtype)[0]
affine_grid = AffineGrid(affine=inv_affine)
grid, _ = affine_grid(orig_size)
Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,7 @@ def _create_shear(spatial_dims: int, coefs: Sequence[float] | float, eye_func=np
def create_scale(
spatial_dims: int,
scaling_factor: Sequence[float] | float,
device: torch.device | None = None,
device: torch.device | str | None = None,
backend=TransformBackends.NUMPY,
) -> NdarrayOrTensor:
"""
Expand Down
Loading

0 comments on commit 3122e1a

Please sign in to comment.