diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index b1f1bbd0f6..32fffc25f0 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1863,7 +1863,7 @@ class Fourier: """ @staticmethod - def shift_fourier(x: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor: + def shift_fourier(x: NdarrayOrTensor, spatial_dims: int, as_contiguous: bool = False) -> NdarrayOrTensor: """ Applies fourier transform and shifts the zero-frequency component to the center of the spectrum. Only the spatial dimensions get transformed. @@ -1871,6 +1871,7 @@ def shift_fourier(x: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor: Args: x: Image to transform. spatial_dims: Number of spatial dimensions. + as_contiguous: Whether to convert the cached NumPy array or PyTorch tensor to be contiguous. Returns k: K-space data. @@ -1885,10 +1886,12 @@ def shift_fourier(x: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor: k = np.fft.fftshift(np.fft.fftn(x.cpu().numpy(), axes=dims), axes=dims) else: k = np.fft.fftshift(np.fft.fftn(x, axes=dims), axes=dims) - return k + return ascontiguousarray(k) if as_contiguous else k @staticmethod - def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None = None) -> NdarrayOrTensor: + def inv_shift_fourier( + k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None = None, as_contiguous: bool = False + ) -> NdarrayOrTensor: """ Applies inverse shift and fourier transform. Only the spatial dimensions are transformed. @@ -1896,6 +1899,7 @@ def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None Args: k: K-space data. spatial_dims: Number of spatial dimensions. + as_contiguous: Whether to convert the cached NumPy array or PyTorch tensor to be contiguous. Returns: x: Tensor in image space. @@ -1910,7 +1914,7 @@ def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None out = np.fft.ifftn(np.fft.ifftshift(k.cpu().numpy(), axes=dims), axes=dims).real else: out = np.fft.ifftn(np.fft.ifftshift(k, axes=dims), axes=dims).real - return out + return ascontiguousarray(out) if as_contiguous else out def get_number_image_type_conversions(transform: Compose, test_data: Any, key: Hashable | None = None) -> int: