5983 update to use np.linalg for the small affine inverse #5967

Merged: 11 commits, Feb 13, 2023
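Summary of the change: for the small (typically 3x3 or 4x4) affine matrices handled in these files, `torch.inverse`/`torch.solve` calls are replaced with NumPy equivalents computed on CPU, and the results are converted back to the input's array type with `convert_to_dst_type`. A minimal sketch of the pattern, using an illustrative affine that is not taken from the diff:

    import numpy as np
    import torch

    from monai.utils.type_conversion import convert_to_dst_type

    # illustrative 3-D homogeneous affine; in the diff these are normalization or spacing affines
    affine = torch.diag(torch.tensor([2.0, 0.5, 1.0, 1.0]))

    # invert on CPU with NumPy instead of torch.inverse ...
    inv_np = np.linalg.inv(affine.cpu().numpy())
    # ... then convert back so the result matches the source tensor's type, device and dtype
    inv = convert_to_dst_type(inv_np, dst=affine)[0]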
monai/networks/utils.py: 5 additions & 4 deletions
@@ -22,14 +22,15 @@
 from copy import deepcopy
 from typing import Any

+import numpy as np
 import torch
 import torch.nn as nn

 from monai.apps.utils import get_logger
 from monai.config import PathLike
 from monai.utils.misc import ensure_tuple, save_obj, set_determinism
 from monai.utils.module import look_up_option, pytorch_after
-from monai.utils.type_conversion import convert_to_tensor
+from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor

 __all__ = [
     "one_hot",
@@ -185,7 +186,7 @@ def predict_segmentation(logits: torch.Tensor, mutually_exclusive: bool = False,

 def normalize_transform(
     shape,
-    device: torch.device | None = None,
+    device: torch.device | str | None = None,
     dtype: torch.dtype | None = None,
     align_corners: bool = False,
     zero_centered: bool = False,
@@ -264,8 +265,8 @@ def to_norm_affine(
         raise ValueError(f"affine suggests {sr}D, got src={len(src_size)}D, dst={len(dst_size)}D.")

     src_xform = normalize_transform(src_size, affine.device, affine.dtype, align_corners, zero_centered)
-    dst_xform = normalize_transform(dst_size, affine.device, affine.dtype, align_corners, zero_centered)
-    return src_xform @ affine @ torch.inverse(dst_xform)
+    dst_xform = normalize_transform(dst_size, "cpu", affine.dtype, align_corners, zero_centered)
+    return src_xform @ affine @ convert_to_dst_type(np.linalg.inv(dst_xform.numpy()), dst=affine)[0]


 def normal_init(
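The `to_norm_affine` change above keeps the public behaviour: only the destination normalization transform is built on CPU and inverted via `np.linalg.inv`, and the returned matrix still follows the input affine's device and dtype. A hedged usage sketch with illustrative spatial sizes:

    import torch

    from monai.networks.utils import to_norm_affine

    # a batch of one 3-D affine, shape (1, 4, 4); any device/dtype the function accepts
    affine = torch.eye(4, dtype=torch.float64).unsqueeze(0)
    norm_affine = to_norm_affine(affine, src_size=(64, 64, 32), dst_size=(32, 32, 32))
    # the inverse is now taken with NumPy internally, but norm_affine stays on affine.device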
monai/transforms/spatial/array.py: 13 additions & 16 deletions
@@ -65,7 +65,6 @@
     fall_back_tuple,
     issequenceiterable,
     optional_import,
-    pytorch_after,
 )
 from monai.utils.deprecate_utils import deprecated_arg
 from monai.utils.enums import GridPatchSort, PytorchPadMode, TraceKeys, TransformBackends, WSIPatchKeys
@@ -272,14 +271,12 @@ def __call__(
             )

         try:
-            _s = convert_to_tensor(src_affine_, track_meta=False, device=torch.device("cpu"))
-            _d = convert_to_tensor(dst_affine, track_meta=False, device=torch.device("cpu"))
-            xform = (
-                torch.linalg.solve(_s, _d) if pytorch_after(1, 8, 0) else torch.solve(_d, _s).solution  # type: ignore
-            )
+            _s = convert_to_numpy(src_affine_)
+            _d = convert_to_numpy(dst_affine)
+            xform = np.linalg.solve(_s, _d)
         except (np.linalg.LinAlgError, RuntimeError) as e:
-            raise ValueError("src affine is not invertible.") from e
-        xform = to_affine_nd(spatial_rank, xform).to(device=img.device, dtype=_dtype)
+            raise ValueError(f"src affine is not invertible {_s}, {_d}.") from e
+        xform = convert_to_tensor(to_affine_nd(spatial_rank, xform)).to(device=img.device, dtype=_dtype)
         # no resampling if it's identity transform
         if allclose(xform, torch.eye(len(xform)), atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size):
             return self._post_process(
@@ -293,12 +290,12 @@
             xform_shape = [-1] + in_spatial_size
             img = img.reshape(xform_shape)  # type: ignore
         if isinstance(mode, int):
-            dst_xform_1 = normalize_transform(spatial_size, xform.device, xform.dtype, True, True)[0]  # to (-1, 1)
+            dst_xform_1 = normalize_transform(spatial_size, "cpu", xform.dtype, True, True)[0].numpy()  # to (-1, 1)
             if not align_corners:
-                norm = create_scale(spatial_rank, [(max(d, 2) - 1) / d for d in spatial_size], xform.device, "torch")
-                dst_xform_1 = norm.to(xform.dtype) @ dst_xform_1  # type: ignore # scaling (num_step - 1) / num_step
-            dst_xform_d = normalize_transform(spatial_size, xform.device, xform.dtype, align_corners, False)[0]
-            xform = xform @ torch.inverse(dst_xform_d) @ dst_xform_1
+                norm = create_scale(spatial_rank, [(max(d, 2) - 1) / d for d in spatial_size])
+                dst_xform_1 = norm.astype(float) @ dst_xform_1  # type: ignore # scaling (num_step - 1) / num_step
+            dst_xform_d = normalize_transform(spatial_size, "cpu", xform.dtype, align_corners, False)[0].numpy()
+            xform @= convert_to_dst_type(np.linalg.solve(dst_xform_d, dst_xform_1), xform)[0]
             affine_xform = Affine(
                 affine=xform, spatial_size=spatial_size, normalized=True, image_only=True, dtype=_dtype  # type: ignore
             )
@@ -1084,7 +1081,7 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor:
         padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"]
         align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"]
         dtype = transform[TraceKeys.EXTRA_INFO]["dtype"]
-        inv_rot_mat = linalg_inv(fwd_rot_mat)
+        inv_rot_mat = linalg_inv(convert_to_numpy(fwd_rot_mat))

         xform = AffineTransform(
             normalized=False,
@@ -2281,7 +2278,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
         fwd_affine = transform[TraceKeys.EXTRA_INFO]["affine"]
         mode = transform[TraceKeys.EXTRA_INFO]["mode"]
         padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"]
-        inv_affine = linalg_inv(fwd_affine)
+        inv_affine = linalg_inv(convert_to_numpy(fwd_affine))
         inv_affine = convert_to_dst_type(inv_affine, data, dtype=inv_affine.dtype)[0]

         affine_grid = AffineGrid(affine=inv_affine)
@@ -2520,7 +2517,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
         fwd_affine = transform[TraceKeys.EXTRA_INFO]["affine"]
         mode = transform[TraceKeys.EXTRA_INFO]["mode"]
         padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"]
-        inv_affine = linalg_inv(fwd_affine)
+        inv_affine = linalg_inv(convert_to_numpy(fwd_affine))
         inv_affine = convert_to_dst_type(inv_affine, data, dtype=inv_affine.dtype)[0]
         affine_grid = AffineGrid(affine=inv_affine)
         grid, _ = affine_grid(orig_size)
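The `inverse`/`inverse_transform` hunks in this file all follow the same pattern: read the tracked forward affine from the applied-operations info, invert it as a NumPy array so that `linalg_inv` takes the NumPy code path, then convert the result back to the image's array type before building the resampling grid. A self-contained sketch of just that pattern; the image and affine below are stand-ins, not values from the diff, and the import paths are assumptions based on where these helpers live in MONAI:

    import torch

    from monai.transforms.utils_pytorch_numpy_unification import linalg_inv
    from monai.utils.type_conversion import convert_to_dst_type, convert_to_numpy

    data = torch.zeros(1, 8, 8, 8)  # stand-in channel-first image tensor
    fwd_affine = torch.diag(torch.tensor([2.0, 2.0, 2.0, 1.0]))  # stand-in tracked forward affine

    inv_affine = linalg_inv(convert_to_numpy(fwd_affine))  # NumPy inverse of a small 4x4 matrix
    inv_affine = convert_to_dst_type(inv_affine, data, dtype=inv_affine.dtype)[0]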
monai/transforms/utils.py: 1 addition & 1 deletion
@@ -828,7 +828,7 @@ def _create_shear(spatial_dims: int, coefs: Sequence[float] | float, eye_func=np
 def create_scale(
     spatial_dims: int,
     scaling_factor: Sequence[float] | float,
-    device: torch.device | None = None,
+    device: torch.device | str | None = None,
     backend=TransformBackends.NUMPY,
 ) -> NdarrayOrTensor:
     """
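Finally, the `device` hints on `create_scale` (and on `normalize_transform` above) are widened so a plain string such as "cpu" is also accepted, matching the new calls in spatial/array.py. With the default NumPy backend the returned matrix is already an ndarray, which is what the `np.linalg.solve` composition above relies on. A small illustrative sketch, assuming the default backend:

    from monai.transforms.utils import create_scale

    # the default backend is TransformBackends.NUMPY, so this returns a NumPy array;
    # the device argument only matters for the torch backend
    scale = create_scale(3, [0.5, 0.5, 1.0])
    print(scale.shape)  # (4, 4) homogeneous scaling matrix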