Skip to content

Commit

Permalink
update meshgrid (#3644)
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
wyli authored Jan 12, 2022
1 parent 12955c5 commit d58e234
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 9 deletions.
3 changes: 2 additions & 1 deletion monai/networks/blocks/crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from torch.nn.functional import softmax

from monai.networks.layers.filtering import PHLFilter
from monai.networks.utils import meshgrid_ij

__all__ = ["CRF"]

Expand Down Expand Up @@ -114,6 +115,6 @@ def forward(self, input_tensor: torch.Tensor, reference_tensor: torch.Tensor):
# helper methods
def _create_coordinate_tensor(tensor):
axes = [torch.arange(tensor.size(i)) for i in range(2, tensor.dim())]
grids = torch.meshgrid(axes)
grids = meshgrid_ij(axes)
coords = torch.stack(grids).to(device=tensor.device, dtype=tensor.dtype)
return torch.stack(tensor.size(0) * [coords], dim=0)
3 changes: 2 additions & 1 deletion monai/networks/blocks/warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from monai.config.deviceconfig import USE_COMPILED
from monai.networks.layers.spatial_transforms import grid_pull
from monai.networks.utils import meshgrid_ij
from monai.utils import GridSampleMode, GridSamplePadMode, optional_import

_C, _ = optional_import("monai._C")
Expand Down Expand Up @@ -84,7 +85,7 @@ def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePa
@staticmethod
def get_reference_grid(ddf: torch.Tensor) -> torch.Tensor:
mesh_points = [torch.arange(0, dim) for dim in ddf.shape[2:]]
grid = torch.stack(torch.meshgrid(*mesh_points), dim=0) # (spatial_dims, ...)
grid = torch.stack(meshgrid_ij(*mesh_points), dim=0) # (spatial_dims, ...)
grid = torch.stack([grid] * ddf.shape[0], dim=0) # (batch, spatial_dims, ...)
grid = grid.to(ddf)
return grid
Expand Down
7 changes: 7 additions & 0 deletions monai/networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"train_mode",
"copy_model_state",
"convert_to_torchscript",
"meshgrid_ij",
]


Expand Down Expand Up @@ -500,3 +501,9 @@ def convert_to_torchscript(
torch.testing.assert_allclose(r1, r2, rtol=rtol, atol=atol)

return script_module


def meshgrid_ij(*tensors):
if pytorch_after(1, 10):
return torch.meshgrid(*tensors, indexing="ij")
return torch.meshgrid(*tensors)
8 changes: 3 additions & 5 deletions monai/transforms/smooth_field/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@

import monai
from monai.config.type_definitions import NdarrayOrTensor
from monai.networks.utils import meshgrid_ij
from monai.transforms.transform import Randomizable, RandomizableTransform
from monai.transforms.utils_pytorch_numpy_unification import moveaxis
from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode
from monai.utils.enums import TransformBackends
from monai.utils.module import look_up_option, pytorch_after
from monai.utils.module import look_up_option
from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor

__all__ = ["SmoothField", "RandSmoothFieldAdjustContrast", "RandSmoothFieldAdjustIntensity", "RandSmoothDeform"]
Expand Down Expand Up @@ -404,10 +405,7 @@ def __init__(
grid_space = spatial_size if spatial_size is not None else self.sfield.field.shape[2:]
grid_ranges = [torch.linspace(-1, 1, d) for d in grid_space]

if pytorch_after(1, 10):
grid = torch.meshgrid(*grid_ranges, indexing="ij")
else:
grid = torch.meshgrid(*grid_ranges)
grid = meshgrid_ij(*grid_ranges)

self.grid = torch.stack(grid).unsqueeze(0).to(self.device, self.grid_dtype)

Expand Down
3 changes: 2 additions & 1 deletion monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.utils import compute_shape_offset, to_affine_nd, zoom_affine
from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull
from monai.networks.utils import meshgrid_ij
from monai.transforms.croppad.array import CenterSpatialCrop, Pad
from monai.transforms.transform import Randomizable, RandomizableTransform, ThreadUnsafe, Transform
from monai.transforms.utils import (
Expand Down Expand Up @@ -2103,7 +2104,7 @@ def __call__(
ranges = ranges - (dim_size - 1.0) / 2.0
all_ranges.append(ranges)

coords = torch.meshgrid(*all_ranges)
coords = meshgrid_ij(*all_ranges)
grid = torch.stack([*coords, torch.ones_like(coords[0])])

return self.resampler(img, grid=grid, mode=mode, padding_mode=padding_mode) # type: ignore
Expand Down
3 changes: 2 additions & 1 deletion tests/test_grid_pull.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from parameterized import parameterized

from monai.networks.layers import grid_pull
from monai.networks.utils import meshgrid_ij
from monai.utils import optional_import
from tests.testing_data.cpp_resample_answers import Expected_1D_GP_bwd, Expected_1D_GP_fwd
from tests.utils import skip_if_no_cpp_extension
Expand All @@ -26,7 +27,7 @@

def make_grid(shape, dtype=None, device=None, requires_grad=True):
ranges = [torch.arange(float(s), dtype=dtype, device=device, requires_grad=requires_grad) for s in shape]
grid = torch.stack(torch.meshgrid(*ranges), dim=-1)
grid = torch.stack(meshgrid_ij(*ranges), dim=-1)
return grid[None]


Expand Down

0 comments on commit d58e234

Please sign in to comment.