-
Notifications
You must be signed in to change notification settings - Fork 0
/
measurements.py
executable file
·214 lines (160 loc) · 6.45 KB
/
measurements.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
"""Measurements module.
Handles task-dependent operators (A) and noise (n) to simulate a measurement y = Ax + n.
Inspired by DPS repository:
- https://github.com/DPS2022/diffusion-posterior-sampling/blob/main/guided_diffusion/measurements.py
"""
import abc
import numpy as np
from jax import tree_util
from keras import ops
from utils import mri
from utils.keras_utils import check_keras_backend
check_keras_backend()
# Module-level registry: maps a registry name to the operator class stored under it.
_OPERATORS = {}


def register_operator(cls=None, *, name=None):
    """Class decorator that stores an operator class in the module registry.

    Works both bare (``@register_operator``) and parameterised
    (``@register_operator(name="...")``).

    Args:
        cls: Class to register. ``None`` when the decorator is invoked with
            keyword arguments only; the real decorator is returned in that case.
        name: Key to register under. Defaults to ``cls.__name__`` when omitted.

    Returns:
        The class itself (unchanged) when ``cls`` is given, otherwise a
        decorator that performs the registration.

    Raises:
        ValueError: If the resolved key is already present in the registry.
    """

    def _decorate(operator_cls):
        # Resolve the key: an explicit name wins over the class name.
        key = operator_cls.__name__ if name is None else name
        if key in _OPERATORS:
            raise ValueError(f"Already registered operator with name: {key}")
        _OPERATORS[key] = operator_cls
        return operator_cls

    return _decorate if cls is None else _decorate(cls)
def get_operator(name):
    """Return the operator class registered under ``name``.

    Args:
        name: Registry key, as used when the class was registered.

    Returns:
        The operator class (not an instance) stored in the registry.

    Raises:
        AssertionError: If ``name`` is not a registered key. Note that this
            check is an ``assert`` and is therefore skipped under ``python -O``,
            in which case the dictionary lookup raises ``KeyError`` instead.
    """
    is_registered = name in _OPERATORS
    assert (
        is_registered
    ), f"Operator {name} not found. Available operators: {_OPERATORS.keys()}"
    return _OPERATORS[name]
class LinearOperator(abc.ABC):
    """Abstract base class for measurement operators of the form y = Ax + n.

    Concrete subclasses must implement ``forward``, ``corrupt``, ``transpose``
    and ``__str__``. The ``_tree_flatten`` / ``_tree_unflatten`` pair describes
    how an instance is split into (children, aux) and rebuilt from them; the
    defaults here describe an operator with no state.
    """

    # Class-level default; presumably the noise level of n (not used by this
    # base class itself) -- TODO confirm against callers.
    sigma = 0.0

    @abc.abstractmethod
    def forward(self, data):
        """Apply the forward operator A, mapping x to y."""
        raise NotImplementedError

    @abc.abstractmethod
    def corrupt(self, data):
        """Produce a corrupted measurement: like ``forward``, but with noise."""
        raise NotImplementedError

    @abc.abstractmethod
    def transpose(self, data):
        """Apply the transpose operator A^T, mapping y back to x."""
        raise NotImplementedError

    @abc.abstractmethod
    def __str__(self):
        """Return a human-readable description of the operator."""
        raise NotImplementedError

    @classmethod
    def _tree_unflatten(cls, aux, children):
        """Rebuild an instance by passing ``children`` positionally to ``cls``.

        ``aux`` is accepted for signature compatibility but ignored.
        """
        rebuilt = cls(*children)
        return rebuilt

    def _tree_flatten(self):
        """Return ``(children, aux)``; both are empty for a stateless operator."""
        children = ()
        aux = ()
        return children, aux
@register_operator(name="inpainting")
class InpaintingOperator(LinearOperator):
    """Inpainting operator A = I * M: element-wise multiplication by a mask."""

    def __init__(self, mask):
        """Keep the mask that is multiplied onto the data."""
        self.mask = mask

    def forward(self, data):
        """Apply A by multiplying ``data`` with the stored mask."""
        masked = data * self.mask
        return masked

    def corrupt(self, data):
        """Return the measurement; this simply delegates to ``forward``."""
        measurement = self.forward(data)
        return measurement

    def transpose(self, data):
        """Apply A^T, which is the same mask multiplication as ``forward``."""
        mask = self.mask
        return data * mask

    def __str__(self):
        """Return a short textual description of the operator."""
        return "y = Ax + n, where A = I * M"

    def _tree_flatten(self):
        """Return ``(children, aux)`` with the mask as the only child."""
        children = (self.mask,)
        aux = ()
        return children, aux
@register_operator(name="fourier")
class FourierOperator(LinearOperator):
    """Fourier operator A = F.

    All transforms are delegated to ``mri.fft2c`` / ``mri.ifft2c``.
    """

    def forward(self, data):
        """Apply A: return ``mri.fft2c(data)``."""
        return mri.fft2c(data)

    def corrupt(self, data):
        """Return the measurement; same transform as ``forward``."""
        return mri.fft2c(data)

    def transpose(self, data):
        """Apply the adjoint of A: return ``mri.ifft2c(data)``.

        The previous implementation used ``raise`` instead of ``return``, which
        made every call fail with a ``TypeError`` (the transform result is not
        an exception).
        """
        # Fourier transform is unitary --> adjoint is inverse
        # https://math.stackexchange.com/questions/1429086/prove-the-fourier-transform-is-a-unitary-linear-operator
        # NOTE(review): this relies on mri.fft2c / mri.ifft2c being a unitary
        # (orthonormally scaled) pair -- confirm in utils.mri.
        return mri.ifft2c(data)

    def __str__(self):
        """Return a short textual description of the operator."""
        return "y = F(x)"
@register_operator(name="masked_fourier")
class MaskedFourierOperator(LinearOperator):
    """Masked Fourier operator A = M*F, where M is a binary mask.

    The mask is applied element-wise to the output of ``mri.fft2c``.
    """

    def __init__(self, mask):
        """Keep the mask that is multiplied onto the Fourier-domain data."""
        self.mask = mask

    def forward(self, data):
        """Apply A: transform with ``mri.fft2c``, then multiply by the mask."""
        return self.mask * mri.fft2c(data)

    def corrupt(self, data):
        """Return the measurement; same computation as ``forward``."""
        return self.mask * mri.fft2c(data)

    def transpose(self, data):
        """Apply the adjoint of A to Fourier-domain data.

        Two fixes compared to the previous implementation:

        * it used ``raise`` instead of ``return``, so every call failed with a
          ``TypeError``;
        * it multiplied by the mask *after* the inverse transform. For
          A = M*F the adjoint is F^H * M^H, so the mask has to be applied in
          the Fourier domain first and the inverse transform second.
        """
        # Fourier transform is unitary --> adjoint is inverse
        # https://math.stackexchange.com/questions/1429086/prove-the-fourier-transform-is-a-unitary-linear-operator
        # NOTE(review): assumes a real-valued mask (M^H == M) and that
        # mri.fft2c / mri.ifft2c form a unitary pair -- confirm in utils.mri.
        return mri.ifft2c(self.mask * data)

    def __str__(self):
        """Return a short textual description of the operator."""
        return "y = M*F(x)"

    def _tree_flatten(self):
        """Return ``(children, aux)`` with the mask as the only child."""
        return (self.mask,), ()
def prepare_measurement(operator_name, target_imgs, **measurement_kwargs):
    """Prepare a measurement given an operator name and target images.

    A convenience helper for quickly generating measurements from clean images.

    Args:
        operator_name (str): Registry name of the operator to use.
        target_imgs (Tensor): The clean images to measure. The first axis is
            treated as the batch axis; ``target_imgs.shape[1:]`` is used as the
            per-image shape when building default masks.
        **measurement_kwargs: Keyword arguments forwarded to the operator
            constructor. If none are given, a default mask is built for the
            operators listed below.

    Returns:
        tuple: ``(operator, measurements)`` where

            - operator: the instantiated forward operator.
            - measurements: the result of ``operator.corrupt(target_imgs)``.

    Raises:
        AssertionError: If ``operator_name`` is not registered (raised by
            ``get_operator``).
        ValueError: If no keyword arguments are given and the operator has no
            built-in default (anything other than the names listed below).

    Note:
        Built-in defaults exist for the following operator names:

            - "inpainting": Inpainting operator with a random mask that keeps
              half of the pixel positions.
            - "masked_fourier": Operator first applying the Fourier transform,
              and then a mask.
    """
    operator_cls = get_operator(operator_name)

    # set defaults for each operator -- configurable by changing operator.mask
    if not measurement_kwargs:
        image_shape = target_imgs.shape[1:]

        if operator_name == "masked_fourier":
            # default to a centered 4x acceleration mask.
            # Fix: allocate zeros of the per-image shape. The previous
            # `ops.zeros_like(shape)` built a 1-D array holding one zero per
            # dimension, which the 3-index update below cannot address.
            # `.at[...].set(...)` requires the tensor to be a JAX array.
            # NOTE(review): the indices are hard-coded (first axis 0, second
            # axis 48..79, last axis 0) and presumably target one specific
            # image layout/size -- confirm against the data pipeline.
            mask = ops.zeros(image_shape).at[0, 32 + 16 : 64 + 16, 0].set(1)
            measurement_kwargs = {"mask": mask}
        elif operator_name == "inpainting":
            # default to a mask keeping a random half of the pixel positions.
            mask = np.zeros(image_shape, dtype="float32")
            n_total_samples = image_shape[0] * image_shape[1]
            # Draw half of the flat pixel indices without repetition ...
            random_idx = np.random.choice(
                n_total_samples, size=n_total_samples // 2, replace=False
            )
            # ... and convert them to indices over all axes but the last, so
            # the assignment below sets every entry along the last axis.
            random_idx = np.unravel_index(random_idx, image_shape[:-1])
            mask[random_idx] = 1
            mask = mask[None, :, :]  # add batch dimension
            measurement_kwargs = {"mask": mask}
        else:
            raise ValueError(f"Operator `{operator_name}` not recognised.")

    operator = operator_cls(**measurement_kwargs)
    measurements = operator.corrupt(target_imgs)
    return operator, measurements
# Make every operator class a JAX pytree node so that operator instances can be
# passed as arguments to jitted JAX functions.
# https://jax.readthedocs.io/en/latest/faq.html#strategy-3-making-customclass-a-pytree
# Note: `__subclasses__()` only lists *direct* subclasses of LinearOperator that
# have been defined by the time this loop runs.
for cls in LinearOperator.__subclasses__():
    tree_util.register_pytree_node(cls, cls._tree_flatten, cls._tree_unflatten)