"""
Base Fourier Operator interface.
from https://github.com/CEA-COSMIC/pysap-mri
:author: Pierre-Antoine Comby
"""
from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from functools import partial, wraps
import numpy as np
from mrinufft._utils import auto_cast, get_array_module, power_method
from mrinufft.density import get_density
from mrinufft.extras import get_smaps
from mrinufft.operators.interfaces.utils import is_cuda_array, is_host_array
CUPY_AVAILABLE = True
try:
import cupy as cp
except ImportError:
CUPY_AVAILABLE = False
AUTOGRAD_AVAILABLE = True
try:
import torch
from mrinufft.operators.autodiff import MRINufftAutoGrad
except ImportError:
AUTOGRAD_AVAILABLE = False
# Mapping between numpy float and complex types.
DTYPE_R2C = {"float32": "complex64", "float64": "complex128"}
def check_backend(backend_name: str):
"""Check if a specific backend is available."""
backend_name = backend_name.lower()
try:
return FourierOperatorBase.interfaces[backend_name][0]
except KeyError as e:
raise ValueError(f"unknown backend: '{backend_name}'") from e
def list_backends(available_only=False):
"""Return a list of backend.
Parameters
----------
available_only: bool, optional
If True, only return backends that are available. If False, return all
backends, regardless of whether they are available or not.
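
    Examples
    --------
    The result depends on the installed packages; a sketch::

        >>> "finufft" in list_backends()  # doctest: +SKIP
        True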
"""
return [
name
for name, (available, _) in FourierOperatorBase.interfaces.items()
if available or not available_only
]
def get_operator(
backend_name: str, wrt_data: bool = False, wrt_traj: bool = False, *args, **kwargs
):
"""Return an MRI Fourier operator interface using the correct backend.
Parameters
----------
backend_name: str
Backend name
    wrt_data: bool, default False
        If set, gradients with respect to the data and image will be available.
    wrt_traj: bool, default False
        If set, gradients with respect to the trajectory will be available.
*args, **kwargs:
Arguments to pass to the operator constructor.
Returns
-------
FourierOperator
        The operator class, or an instance of it if args or kwargs are given.
Raises
------
ValueError if the backend is not available.
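
    Examples
    --------
    A minimal sketch; the backend name, trajectory ``traj`` and image ``img``
    are placeholders and assume the corresponding backend is installed::

        nufft = get_operator("finufft", samples=traj, shape=(256, 256), n_coils=1)
        kspace = nufft.op(img)
        img_back = nufft.adj_op(kspace)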
"""
available = True
backend_name = backend_name.lower()
try:
available, operator = FourierOperatorBase.interfaces[backend_name]
except KeyError as exc:
if not backend_name.startswith("stacked-"):
raise ValueError(f"backend {backend_name} does not exist") from exc
        # Fall back to the generic "stacked" operator with the given backend.
        # Dedicated registered stacked backends (like stacked-cufinufft)
        # would have been found earlier.
backend = backend_name.split("-")[1]
operator = get_operator("stacked")
operator = partial(operator, backend=backend)
if not available:
raise ValueError(f"backend {backend_name} found, but dependencies are not met.")
if args or kwargs:
operator = operator(*args, **kwargs)
    # Wrap with autograd support if requested.
if isinstance(operator, FourierOperatorBase):
operator = operator.make_autograd(wrt_data, wrt_traj)
elif wrt_data or wrt_traj: # instance will be created later
operator = partial(operator.with_autograd, wrt_data, wrt_traj)
return operator
def with_numpy(fun):
"""Ensure the function works internally with numpy array."""
@wraps(fun)
def wrapper(self, data, *args, **kwargs):
if hasattr(data, "__cuda_array_interface__"):
warnings.warn("data is on gpu, it will be moved to CPU.")
xp = get_array_module(data)
if xp.__name__ == "torch":
data_ = data.to("cpu").numpy()
elif xp.__name__ == "cupy":
data_ = data.get()
elif xp.__name__ == "numpy":
data_ = data
else:
raise ValueError(f"Array library {xp} not supported.")
ret_ = fun(self, data_, *args, **kwargs)
if xp.__name__ == "torch":
if data.is_cpu:
return xp.from_numpy(ret_)
return xp.from_numpy(ret_).to(data.device)
elif xp.__name__ == "cupy":
return xp.array(ret_)
else:
return ret_
return wrapper
def with_numpy_cupy(fun):
"""Ensure the function works internally with numpy or cupy array."""
@wraps(fun)
def wrapper(self, data, output=None, *args, **kwargs):
xp = get_array_module(data)
if xp.__name__ == "torch" and is_cuda_array(data):
# Move them to cupy
data_ = cp.from_dlpack(data)
output_ = cp.from_dlpack(output) if output is not None else None
elif xp.__name__ == "torch":
# Move to numpy
data_ = data.to("cpu").numpy()
output_ = output.to("cpu").numpy() if output is not None else None
else:
data_ = data
output_ = output
if output_ is not None:
if not (
(is_host_array(data_) and is_host_array(output_))
or (is_cuda_array(data_) and is_cuda_array(output_))
):
raise ValueError(
"input data and output should be " "on the same memory space."
)
ret_ = fun(self, data_, output_, *args, **kwargs)
if xp.__name__ == "torch" and is_cuda_array(data):
return xp.as_tensor(ret_, device=data.device)
if xp.__name__ == "torch":
if data.is_cpu:
return xp.from_numpy(ret_)
return xp.from_numpy(ret_).to(data.device)
return ret_
return wrapper
class FourierOperatorBase(ABC):
"""Base Fourier Operator class.
    Every (linear) Fourier operator inherits from this class,
    to ensure that all the methods required by ModOpt are
    implemented correctly.
"""
interfaces: dict[str, tuple] = {}
autograd_available = False
def __init__(self):
if not self.available:
raise RuntimeError(f"'{self.backend}' backend is not available.")
self._smaps = None
self._density = None
self._n_coils = 1
def __init_subclass__(cls):
"""Register the class in the list of available operators."""
super().__init_subclass__()
available = getattr(cls, "available", True)
if callable(available):
available = available()
if backend := getattr(cls, "backend", None):
cls.interfaces[backend] = (available, cls)
@abstractmethod
def op(self, data):
"""Compute operator transform.
Parameters
----------
data: np.ndarray
input as array.
Returns
-------
result: np.ndarray
operator transform of the input.
"""
pass
@abstractmethod
def adj_op(self, coeffs):
"""Compute adjoint operator transform.
Parameters
----------
        coeffs: np.ndarray
input data array.
Returns
-------
results: np.ndarray
adjoint operator transform.
"""
pass
def data_consistency(self, image, obs_data):
"""Compute the gradient data consistency.
This is the naive implementation using adj_op(op(x)-y).
Specific backend can (and should!) implement a more efficient version.
"""
return self.adj_op(self.op(image) - obs_data)
def with_off_resonnance_correction(self, B, C, indices):
"""Return a new operator with Off Resonnance Correction."""
from ..off_resonnance import MRIFourierCorrected
return MRIFourierCorrected(self, B, C, indices)
def compute_smaps(self, method=None):
"""Compute the sensitivity maps and set it.
Parameters
----------
method: callable or dict or array
The method to use to compute the sensitivity maps.
If an array, it should be of shape (NCoils,XYZ) and will be used as is.
            If a dict, it should have a key 'name' to determine which method to use;
            the other items will be passed as kwargs.
            If a callable, it should take the samples and the shape as input.
            Note that this callable should also hold the k-space data
            (use functools.partial).
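
        Examples
        --------
        A hedged sketch; "low_frequency" is assumed to be a registered smaps
        method and ``kspace_data`` a measured k-space array::

            nufft.compute_smaps({"name": "low_frequency", "kspace_data": kspace_data})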
"""
if isinstance(method, np.ndarray):
self.smaps = method
return None
if not method:
self.smaps = None
return None
kwargs = {}
if isinstance(method, dict):
kwargs = method.copy()
method = kwargs.pop("name")
if isinstance(method, str):
method = get_smaps(method)
if not callable(method):
raise ValueError(f"Unknown smaps method: {method}")
self.smaps, self.SOS = method(
self.samples,
self.shape,
density=self.density,
backend=self.backend,
**kwargs,
)
def make_autograd(self, wrt_data=True, wrt_traj=False):
"""Make a new Operator with autodiff support.
Parameters
----------
        wrt_data : bool, optional
            If True, compute the gradient with respect to the data. Default is True.
        wrt_traj : bool, optional
            If True, compute the gradient with respect to the trajectory.
            Default is False.
Returns
-------
torch.nn.module
A NUFFT operator with autodiff capabilities.
Raises
------
ValueError
If autograd is not available.
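
        Examples
        --------
        A sketch, assuming torch is installed, the backend supports autodiff,
        and ``image`` is a torch tensor with ``requires_grad=True``::

            nufft_ad = nufft.make_autograd(wrt_data=True)
            loss = torch.linalg.norm(nufft_ad.op(image) - kspace_obs)
            loss.backward()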
"""
if not AUTOGRAD_AVAILABLE:
raise ValueError("Autograd not available, ensure torch is installed.")
if not self.autograd_available:
raise ValueError("Backend does not support auto-differentiation.")
return MRINufftAutoGrad(self, wrt_data=wrt_data, wrt_traj=wrt_traj)
def compute_density(self, method=None):
"""Compute the density compensation weights and set it.
Parameters
----------
method: str or callable or array or dict or bool
The method to use to compute the density compensation.
If a string, the method should be registered in the density registry.
If a callable, it should take the samples and the shape as input.
            If a dict, it should have a key 'name' to determine which method to use;
            the other items will be passed as kwargs.
If an array, it should be of shape (Nsamples,) and will be used as is.
If a bool, it will enable or disable the density compensation.
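
        Examples
        --------
        A hedged sketch; "voronoi" is assumed to be a registered density method
        and "pipe" is only available for some backends::

            nufft.compute_density("voronoi")
            nufft.compute_density(True)  # uses "pipe" if the backend supports it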
"""
if not method:
self.density = None
return None
if method is True and hasattr(self, "pipe"):
method = "pipe"
if isinstance(method, np.ndarray):
self.density = method
return None
kwargs = {}
if isinstance(method, dict):
kwargs = method.copy()
method = kwargs.pop("name") # must be a string !
if method == "pipe" and "backend" not in kwargs:
kwargs["backend"] = self.backend
if isinstance(method, str):
method = get_density(method)
if not callable(method):
raise ValueError(f"Unknown density method: {method}")
self.density = method(self.samples, self.shape, **kwargs)
def get_lipschitz_cst(self, max_iter=10, **kwargs):
"""Return the Lipschitz constant of the operator.
Parameters
----------
        max_iter: int
            Number of iterations used to compute the Lipschitz constant.
        **kwargs:
            Extra arguments passed to the temporary single-coil operator
            constructor (only used when n_coils > 1).
Returns
-------
float
Spectral Radius
Notes
-----
        This uses the iterative power method to compute the largest singular value
        of a single-coil version of the NUFFT operator. No coil or B0 correction is
        applied, but any computed density compensation is included.
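
        Examples
        --------
        Typical use for choosing a gradient step size (sketch)::

            L = nufft.get_lipschitz_cst(max_iter=20)
            step_size = 1.0 / L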
"""
if self.n_coils > 1:
tmp_op = self.__class__(
self.samples, self.shape, density=self.density, n_coils=1, **kwargs
)
else:
tmp_op = self
return power_method(max_iter, tmp_op)
@property
def uses_sense(self):
"""Return True if the operator uses sensitivity maps."""
return self._smaps is not None
@property
def uses_density(self):
"""Return True if the operator uses density compensation."""
return getattr(self, "density", None) is not None
@property
def ndim(self):
"""Number of dimensions in image space of the operator."""
return len(self._shape)
@property
def shape(self):
"""Shape of the image space of the operator."""
return self._shape
@shape.setter
def shape(self, shape):
self._shape = tuple(shape)
@property
def n_coils(self):
"""Number of coils for the operator."""
return self._n_coils
@n_coils.setter
def n_coils(self, n_coils):
if n_coils < 1 or not int(n_coils) == n_coils:
raise ValueError(f"n_coils should be a positive integer, {type(n_coils)}")
self._n_coils = int(n_coils)
@property
def smaps(self):
"""Sensitivity maps of the operator."""
return self._smaps
@smaps.setter
def smaps(self, smaps):
if smaps is None:
self._smaps = None
elif len(smaps) != self.n_coils:
raise ValueError(
f"Number of sensitivity maps ({len(smaps)})"
f"should be equal to n_coils ({self.n_coils})"
)
else:
self._smaps = smaps
@property
def density(self):
"""Density compensation of the operator."""
return self._density
@density.setter
def density(self, density):
if density is None:
self._density = None
elif len(density) != self.n_samples:
raise ValueError("Density and samples should have the same length")
else:
self._density = density
@property
def dtype(self):
"""Return floating precision of the operator."""
return self._dtype
@dtype.setter
def dtype(self, dtype):
self._dtype = np.dtype(dtype)
@property
def cpx_dtype(self):
"""Return complex floating precision of the operator."""
return np.dtype(DTYPE_R2C[str(self.dtype)])
@property
def samples(self):
"""Return the samples used by the operator."""
return self._samples
@samples.setter
def samples(self, samples):
self._samples = samples
@property
def n_samples(self):
"""Return the number of samples used by the operator."""
return self._samples.shape[0]
@property
def norm_factor(self):
"""Normalization factor of the operator."""
return np.sqrt(np.prod(self.shape) * (2 ** len(self.shape)))
def __repr__(self):
"""Return info about the Fourier operator."""
return (
f"{self.__class__.__name__}(\n"
f" shape: {self.shape}\n"
f" n_coils: {self.n_coils}\n"
f" n_samples: {self.n_samples}\n"
f" uses_sense: {self.uses_sense}\n"
")"
)
@classmethod
def with_autograd(cls, wrt_data=True, wrt_traj=False, *args, **kwargs):
"""Return a Fourier operator with autograd capabilities."""
return cls(*args, **kwargs).make_autograd(wrt_data, wrt_traj)
class FourierOperatorCPU(FourierOperatorBase):
"""Base class for CPU-based NUFFT operator.
The NUFFT operation will be done sequentially and looped over coils and batches.
Parameters
----------
samples: np.ndarray
The samples used by the operator.
shape: tuple
The shape of the image space (in 2D or 3D)
density: bool or np.ndarray
If True, the density compensation is estimated from the samples.
If False, no density compensation is applied.
If np.ndarray, the density compensation is applied from the array.
n_coils: int
The number of coils.
smaps: np.ndarray
The sensitivity maps.
raw_op: object
        An object implementing the NUFFT API. It should be responsible for
        computing a single type 1 / type 2 NUFFT.
"""
def __init__(
self,
samples,
shape,
density=False,
n_coils=1,
n_batchs=1,
n_trans=1,
smaps=None,
raw_op=None,
squeeze_dims=True,
):
super().__init__()
self.shape = shape
# we will access the samples by their coordinate first.
self.samples = samples.reshape(-1, len(shape))
self.dtype = self.samples.dtype
if n_coils < 1:
raise ValueError("n_coils should be ≥ 1")
self.n_coils = n_coils
self.n_batchs = n_batchs
self.n_trans = n_trans
self.squeeze_dims = squeeze_dims
# Density Compensation Setup
self.compute_density(density)
# Multi Coil Setup
self.compute_smaps(smaps)
self.raw_op = raw_op
@with_numpy
def op(self, data, ksp=None):
r"""Non Cartesian MRI forward operator.
Parameters
----------
data: np.ndarray
The uniform (2D or 3D) data in image space.
Returns
-------
Results array on the same device as data.
Notes
-----
        This performs, for every coil :math:`\ell`:

        .. math:: \mathcal{F}\mathcal{S}_\ell x
"""
# sense
data = auto_cast(data, self.cpx_dtype)
if self.uses_sense:
ret = self._op_sense(data, ksp)
# calibrationless or monocoil.
else:
ret = self._op_calibless(data, ksp)
ret /= self.norm_factor
ret = self._safe_squeeze(ret)
return ret
def _op_sense(self, data, ksp=None):
T, B, C = self.n_trans, self.n_batchs, self.n_coils
K, XYZ = self.n_samples, self.shape
dataf = data.reshape((B, *XYZ))
if ksp is None:
ksp = np.empty((B * C, K), dtype=self.cpx_dtype)
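        # The flat index enumerates (batch, coil) pairs with the coil index
        # varying fastest; T transforms are processed per iteration.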
for i in range(B * C // T):
idx_coils = np.arange(i * T, (i + 1) * T) % C
idx_batch = np.arange(i * T, (i + 1) * T) // C
coil_img = self.smaps[idx_coils].copy().reshape((T, *XYZ))
coil_img *= dataf[idx_batch]
self._op(coil_img, ksp[i * T : (i + 1) * T])
ksp = ksp.reshape((B, C, K))
return ksp
def _op_calibless(self, data, ksp=None):
T, B, C = self.n_trans, self.n_batchs, self.n_coils
K, XYZ = self.n_samples, self.shape
if ksp is None:
ksp = np.empty((B * C, K), dtype=self.cpx_dtype)
dataf = np.reshape(data, (B * C, *XYZ))
for i in range((B * C) // T):
self._op(
dataf[i * T : (i + 1) * T],
ksp[i * T : (i + 1) * T],
)
ksp = ksp.reshape((B, C, K))
return ksp
def _op(self, image, coeffs):
self.raw_op.op(coeffs, image)
@with_numpy
def adj_op(self, coeffs, img=None):
"""Non Cartesian MRI adjoint operator.
Parameters
----------
coeffs: np.array or GPUArray
Returns
-------
        Array in the same memory space as coeffs (i.e. on CPU or GPU memory).
"""
coeffs = auto_cast(coeffs, self.cpx_dtype)
if self.uses_sense:
ret = self._adj_op_sense(coeffs, img)
# calibrationless or monocoil.
else:
ret = self._adj_op_calibless(coeffs, img)
ret /= self.norm_factor
return self._safe_squeeze(ret)
def _adj_op_sense(self, coeffs, img=None):
T, B, C = self.n_trans, self.n_batchs, self.n_coils
K, XYZ = self.n_samples, self.shape
if img is None:
img = np.zeros((B, *XYZ), dtype=self.cpx_dtype)
coeffs_flat = coeffs.reshape((B * C, K))
img_batched = np.zeros((T, *XYZ), dtype=self.cpx_dtype)
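        # Accumulate the conjugate-smaps-weighted coil images into the
        # corresponding batch image, T transforms at a time.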
for i in range(B * C // T):
idx_coils = np.arange(i * T, (i + 1) * T) % C
idx_batch = np.arange(i * T, (i + 1) * T) // C
self._adj_op(coeffs_flat[i * T : (i + 1) * T], img_batched)
img_batched *= self.smaps[idx_coils].conj()
for t, b in enumerate(idx_batch):
img[b] += img_batched[t]
img = img.reshape((B, 1, *XYZ))
return img
def _adj_op_calibless(self, coeffs, img=None):
T, B, C = self.n_trans, self.n_batchs, self.n_coils
K, XYZ = self.n_samples, self.shape
if img is None:
img = np.empty((B * C, *XYZ), dtype=self.cpx_dtype)
coeffs_f = np.reshape(coeffs, (B * C, K))
for i in range((B * C) // T):
self._adj_op(coeffs_f[i * T : (i + 1) * T], img[i * T : (i + 1) * T])
img = img.reshape((B, C, *XYZ))
return img
def _adj_op(self, coeffs, image):
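        # Apply density compensation to each transform before calling the
        # raw adjoint NUFFT.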
if self.density is not None:
coeffs2 = coeffs.copy()
for i in range(self.n_trans):
coeffs2[i * self.n_samples : (i + 1) * self.n_samples] *= self.density
else:
coeffs2 = coeffs
self.raw_op.adj_op(coeffs2, image)
def data_consistency(self, image_data, obs_data):
"""Compute the gradient data consistency.
        This mixes the op and adj_op methods to perform :math:`F^*(F x - y)`
        on a per-coil basis. By doing the computation coil-wise,
        it uses less memory than the naive call to adj_op(op(x) - y).
Parameters
----------
        image_data: array
            Image on which the gradient operation will be evaluated.
            Shape is (n_coils, *image_shape) when not using SENSE.
obs_data: array
Observed data.
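
        Examples
        --------
        Gradient step for a least-squares data-fidelity term (sketch)::

            grad = nufft.data_consistency(image, kspace_obs)
            image = image - step_size * grad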
"""
if self.uses_sense:
return self._safe_squeeze(self._grad_sense(image_data, obs_data))
return self._safe_squeeze(self._grad_calibless(image_data, obs_data))
def _grad_sense(self, image_data, obs_data):
T, B, C = self.n_trans, self.n_batchs, self.n_coils
K, XYZ = self.n_samples, self.shape
dataf = image_data.reshape((B, *XYZ))
obs_dataf = obs_data.reshape((B * C, K))
grad = np.empty_like(dataf)
coil_img = np.empty((T, *XYZ), dtype=self.cpx_dtype)
coil_ksp = np.empty((T, K), dtype=self.cpx_dtype)
for i in range(B * C // T):
idx_coils = np.arange(i * T, (i + 1) * T) % C
idx_batch = np.arange(i * T, (i + 1) * T) // C
coil_img = self.smaps[idx_coils].copy().reshape((T, *XYZ))
coil_img *= dataf[idx_batch]
self._op(coil_img, coil_ksp)
coil_ksp /= self.norm_factor
coil_ksp -= obs_dataf[i * T : (i + 1) * T]
self._adj_op(coil_ksp, coil_img)
coil_img *= self.smaps[idx_coils].conj()
for t, b in enumerate(idx_batch):
grad[b] += coil_img[t]
grad /= self.norm_factor
return grad
def _grad_calibless(self, image_data, obs_data):
T, B, C = self.n_trans, self.n_batchs, self.n_coils
K, XYZ = self.n_samples, self.shape
dataf = image_data.reshape((B * C, *XYZ))
obs_dataf = obs_data.reshape((B * C, K))
grad = np.empty_like(dataf)
ksp = np.empty((T, K), dtype=self.cpx_dtype)
for i in range(B * C // T):
self._op(dataf[i * T : (i + 1) * T], ksp)
ksp /= self.norm_factor
ksp -= obs_dataf[i * T : (i + 1) * T]
if self.uses_density:
ksp *= self.density
self._adj_op(ksp, grad[i * T : (i + 1) * T])
grad /= self.norm_factor
return grad.reshape(B, C, *XYZ)
def _safe_squeeze(self, arr):
"""Squeeze the first two dimensions of shape of the operator."""
if self.squeeze_dims:
try:
arr = arr.squeeze(axis=1)
except ValueError:
pass
try:
arr = arr.squeeze(axis=0)
except ValueError:
pass
return arr