-
Notifications
You must be signed in to change notification settings - Fork 3
/
AcqInfo.py
260 lines (200 loc) · 9.31 KB
/
AcqInfo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
"""Acquisition information dataclass."""
from collections.abc import Sequence
from dataclasses import dataclass
import ismrmrd
import numpy as np
import torch
from einops import rearrange
from typing_extensions import Self
from mrpro.data.MoveDataMixin import MoveDataMixin
from mrpro.data.Rotation import Rotation
from mrpro.data.SpatialDimension import SpatialDimension
from mrpro.utils.unit_conversion import mm_to_m
def rearrange_acq_info_fields(field: object, pattern: str, **axes_lengths: int) -> object:
    """Change the shape of the fields in AcqInfo.

    Parameters
    ----------
    field
        A single field value, e.g. of AcqInfo or AcqIdx.
    pattern
        einops ``rearrange`` pattern applied to the field.
    axes_lengths
        Sizes of axes in ``pattern`` that cannot be inferred from the input,
        forwarded unchanged to ``einops.rearrange``.

    Returns
    -------
        The rearranged field if it is a ``Rotation`` or a ``torch.Tensor``,
        otherwise ``field`` itself, unchanged.
    """
    if isinstance(field, Rotation):
        # Rearrange the underlying rotation matrices and rebuild a Rotation from them.
        # NOTE(review): as_matrix() presumably carries trailing (3, 3) matrix axes,
        # so `pattern` must account for them — confirm with callers.
        return Rotation.from_matrix(rearrange(field.as_matrix(), pattern, **axes_lengths))

    if isinstance(field, torch.Tensor):
        return rearrange(field, pattern, **axes_lengths)

    # Any other field type is passed through as-is.
    return field
@dataclass(slots=True)
class AcqIdx(MoveDataMixin):
    """Acquisition indices (counters) of each readout."""

    k1: torch.Tensor
    """Index of the first phase-encoding step."""

    k2: torch.Tensor
    """Index of the second phase-encoding step."""

    average: torch.Tensor
    """Index of the signal average."""

    slice: torch.Tensor
    """Slice index (for multi-slice 2D acquisitions)."""

    contrast: torch.Tensor
    """Echo index in multi-echo acquisitions."""

    phase: torch.Tensor
    """Cardiac phase index."""

    repetition: torch.Tensor
    """Repetition counter for repeated/dynamic acquisitions."""

    set: torch.Tensor
    """Set index for different preparations, e.g. flow encoding or diffusion weighting."""

    segment: torch.Tensor
    """Segment counter for segmented acquisitions."""

    user0: torch.Tensor
    """User-defined index no. 0."""

    user1: torch.Tensor
    """User-defined index no. 1."""

    user2: torch.Tensor
    """User-defined index no. 2."""

    user3: torch.Tensor
    """User-defined index no. 3."""

    user4: torch.Tensor
    """User-defined index no. 4."""

    user5: torch.Tensor
    """User-defined index no. 5."""

    user6: torch.Tensor
    """User-defined index no. 6."""

    user7: torch.Tensor
    """User-defined index no. 7."""
@dataclass(slots=True)
class AcqInfo(MoveDataMixin):
"""Acquisition information for each readout."""
idx: AcqIdx
"""Indices describing acquisitions (i.e. readouts)."""
acquisition_time_stamp: torch.Tensor
"""Clock time stamp. Not in s but in vendor-specific time units (e.g. 2.5ms for Siemens)"""
active_channels: torch.Tensor
"""Number of active receiver coil elements."""
available_channels: torch.Tensor
"""Number of available receiver coil elements."""
center_sample: torch.Tensor
"""Index of the readout sample corresponding to k-space center (zero indexed)."""
channel_mask: torch.Tensor
"""Bit mask indicating active coils (64*16 = 1024 bits)."""
discard_post: torch.Tensor
"""Number of readout samples to be discarded at the end (e.g. if the ADC is active during gradient events)."""
discard_pre: torch.Tensor
"""Number of readout samples to be discarded at the beginning (e.g. if the ADC is active during gradient events)"""
encoding_space_ref: torch.Tensor
"""Indexed reference to the encoding spaces enumerated in the MRD (xml) header."""
flags: torch.Tensor
"""A bit mask of common attributes applicable to individual acquisition readouts."""
measurement_uid: torch.Tensor
"""Unique ID corresponding to the readout."""
number_of_samples: torch.Tensor
"""Number of sample points per readout (readouts may have different number of sample points)."""
orientation: Rotation
"""Rotation describing the orientation of the readout, phase and slice encoding direction."""
patient_table_position: SpatialDimension[torch.Tensor]
"""Offset position of the patient table, in LPS coordinates [m]."""
physiology_time_stamp: torch.Tensor
"""Time stamps relative to physiological triggering, e.g. ECG. Not in s but in vendor-specific time units"""
position: SpatialDimension[torch.Tensor]
"""Center of the excited volume, in LPS coordinates relative to isocenter [m]."""
sample_time_us: torch.Tensor
"""Readout bandwidth, as time between samples [us]."""
scan_counter: torch.Tensor
"""Zero-indexed incrementing counter for readouts."""
trajectory_dimensions: torch.Tensor # =3. We only support 3D Trajectories: kz always exists.
"""Dimensionality of the k-space trajectory vector."""
user_float: torch.Tensor
"""User-defined float parameters."""
user_int: torch.Tensor
"""User-defined int parameters."""
version: torch.Tensor
"""Major version number."""
@classmethod
def from_ismrmrd_acquisitions(cls, acquisitions: Sequence[ismrmrd.Acquisition]) -> Self:
"""Read the header of a list of acquisition and store information.
Parameters
----------
acquisitions:
list of ismrmrd acquisistions to read from. Needs at least one acquisition.
"""
# Idea: create array of structs, then a struct of arrays,
# convert it into tensors to store in our dataclass.
# TODO: there might be a faster way to do this.
if len(acquisitions) == 0:
raise ValueError('Acquisition list must not be empty.')
# Creating the dtype first and casting to bytes
# is a workaround for a bug in cpython > 3.12 causing a warning
# is np.array(AcquisitionHeader) is called directly.
# also, this needs to check the dtyoe only once.
acquisition_head_dtype = np.dtype(ismrmrd.AcquisitionHeader)
headers = np.frombuffer(
np.array([memoryview(a._head).cast('B') for a in acquisitions]),
dtype=acquisition_head_dtype,
)
idx = headers['idx']
def tensor(data: np.ndarray) -> torch.Tensor:
# we have to convert first as pytoch cant create tensors from np.uint16 arrays
# we use int32 for uint16 and int64 for uint32 to fit largest values.
match data.dtype:
case np.uint16:
data = data.astype(np.int32)
case np.uint32 | np.uint64:
data = data.astype(np.int64)
# Remove any uncessary dimensions
return torch.tensor(np.squeeze(data))
def tensor_2d(data: np.ndarray) -> torch.Tensor:
# Convert tensor to torch dtypes and ensure it is atleast 2D
data_tensor = tensor(data)
# Ensure that data is (k1*k2*other, >=1)
if data_tensor.ndim == 1:
data_tensor = data_tensor[:, None]
elif data_tensor.ndim == 0:
data_tensor = data_tensor[None, None]
return data_tensor
def spatialdimension_2d(data: np.ndarray) -> SpatialDimension[torch.Tensor]:
# Ensure spatial dimension is (k1*k2*other, 1, 3)
if data.ndim != 2:
raise ValueError('Spatial dimension is expected to be of shape (N,3)')
data = data[:, None, :]
# all spatial dimensions are float32
return SpatialDimension[torch.Tensor].from_array_xyz(torch.tensor(data.astype(np.float32)))
acq_idx = AcqIdx(
k1=tensor(idx['kspace_encode_step_1']),
k2=tensor(idx['kspace_encode_step_2']),
average=tensor(idx['average']),
slice=tensor(idx['slice']),
contrast=tensor(idx['contrast']),
phase=tensor(idx['phase']),
repetition=tensor(idx['repetition']),
set=tensor(idx['set']),
segment=tensor(idx['segment']),
user0=tensor(idx['user'][:, 0]),
user1=tensor(idx['user'][:, 1]),
user2=tensor(idx['user'][:, 2]),
user3=tensor(idx['user'][:, 3]),
user4=tensor(idx['user'][:, 4]),
user5=tensor(idx['user'][:, 5]),
user6=tensor(idx['user'][:, 6]),
user7=tensor(idx['user'][:, 7]),
)
acq_info = cls(
idx=acq_idx,
acquisition_time_stamp=tensor_2d(headers['acquisition_time_stamp']),
active_channels=tensor_2d(headers['active_channels']),
available_channels=tensor_2d(headers['available_channels']),
center_sample=tensor_2d(headers['center_sample']),
channel_mask=tensor_2d(headers['channel_mask']),
discard_post=tensor_2d(headers['discard_post']),
discard_pre=tensor_2d(headers['discard_pre']),
encoding_space_ref=tensor_2d(headers['encoding_space_ref']),
flags=tensor_2d(headers['flags']),
measurement_uid=tensor_2d(headers['measurement_uid']),
number_of_samples=tensor_2d(headers['number_of_samples']),
orientation=Rotation.from_directions(
spatialdimension_2d(headers['slice_dir']),
spatialdimension_2d(headers['phase_dir']),
spatialdimension_2d(headers['read_dir']),
),
patient_table_position=spatialdimension_2d(headers['patient_table_position']).apply_(mm_to_m),
physiology_time_stamp=tensor_2d(headers['physiology_time_stamp']),
position=spatialdimension_2d(headers['position']).apply_(mm_to_m),
sample_time_us=tensor_2d(headers['sample_time_us']),
scan_counter=tensor_2d(headers['scan_counter']),
trajectory_dimensions=tensor_2d(headers['trajectory_dimensions']).fill_(3), # see above
user_float=tensor_2d(headers['user_float']),
user_int=tensor_2d(headers['user_int']),
version=tensor_2d(headers['version']),
)
return acq_info