Skip to content

Commit

Permalink
Added as_contiguous arg following recommedations from community.
Browse files Browse the repository at this point in the history
  • Loading branch information
bwittmann committed Sep 3, 2024
1 parent fb65555 commit 4df5947
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1633,14 +1633,15 @@ class Fourier:
"""

@staticmethod
def shift_fourier(x: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor:
def shift_fourier(x: NdarrayOrTensor, spatial_dims: int, as_contiguous: bool = False) -> NdarrayOrTensor:
"""
Applies fourier transform and shifts the zero-frequency component to the
center of the spectrum. Only the spatial dimensions get transformed.
Args:
x: Image to transform.
spatial_dims: Number of spatial dimensions.
as_contiguous: Whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
Returns
k: K-space data.
Expand All @@ -1649,23 +1650,24 @@ def shift_fourier(x: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor:
k: NdarrayOrTensor
if isinstance(x, torch.Tensor):
if hasattr(torch.fft, "fftshift"): # `fftshift` is new in torch 1.8.0
k = torch.fft.fftshift(torch.fft.fftn(x, dim=dims), dim=dims).contiguous()
k = torch.fft.fftshift(torch.fft.fftn(x, dim=dims), dim=dims)
else:
# if using old PyTorch, will convert to numpy array and return
k = np.ascontiguousarray(np.fft.fftshift(np.fft.fftn(x.cpu().numpy(), axes=dims), axes=dims))
k = np.fft.fftshift(np.fft.fftn(x.cpu().numpy(), axes=dims), axes=dims)
else:
k = np.ascontiguousarray(np.fft.fftshift(np.fft.fftn(x, axes=dims), axes=dims))
return k
k = np.fft.fftshift(np.fft.fftn(x, axes=dims), axes=dims)
return ascontiguousarray(k) if as_contiguous else k

@staticmethod
def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None = None) -> NdarrayOrTensor:
def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None = None, as_contiguous: bool = False) -> NdarrayOrTensor:
"""
Applies inverse shift and fourier transform. Only the spatial
dimensions are transformed.
Args:
k: K-space data.
spatial_dims: Number of spatial dimensions.
as_contiguous: Whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
Returns:
x: Tensor in image space.
Expand All @@ -1674,13 +1676,13 @@ def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None
out: NdarrayOrTensor
if isinstance(k, torch.Tensor):
if hasattr(torch.fft, "ifftshift"): # `ifftshift` is new in torch 1.8.0
out = torch.fft.ifftn(torch.fft.ifftshift(k, dim=dims), dim=dims, norm="backward").real.contiguous()
out = torch.fft.ifftn(torch.fft.ifftshift(k, dim=dims), dim=dims, norm="backward").real
else:
# if using old PyTorch, will convert to numpy array and return
out = np.ascontiguousarray(np.fft.ifftn(np.fft.ifftshift(k.cpu().numpy(), axes=dims), axes=dims).real)
out = np.fft.ifftn(np.fft.ifftshift(k.cpu().numpy(), axes=dims), axes=dims).real
else:
out = np.ascontiguousarray(np.fft.ifftn(np.fft.ifftshift(k, axes=dims), axes=dims).real)
return out
out = np.fft.ifftn(np.fft.ifftshift(k, axes=dims), axes=dims).real
return ascontiguousarray(out) if as_contiguous else out


def get_number_image_type_conversions(transform: Compose, test_data: Any, key: Hashable | None = None) -> int:
Expand Down

0 comments on commit 4df5947

Please sign in to comment.