Skip to content

Commit

Permalink
Refactor AcqInfo and Header information
Browse files Browse the repository at this point in the history
ghstack-source-id: c7cf5277ab084e407c3c5d45ac3ed2145c8b976d
ghstack-comment-id: 2501006700
Pull Request resolved: #560
  • Loading branch information
fzimmermann89 committed Dec 16, 2024
1 parent 53103a5 commit 3c54fae
Show file tree
Hide file tree
Showing 14 changed files with 542 additions and 458 deletions.
183 changes: 125 additions & 58 deletions src/mrpro/data/AcqInfo.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Acquisition information dataclass."""

from collections.abc import Sequence
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import Literal, TypeAlias, overload

import ismrmrd
import numpy as np
Expand All @@ -26,6 +27,25 @@ def rearrange_acq_info_fields(field: object, pattern: str, **axes_lengths: dict[
return field


_convert_time_stamp_type: TypeAlias = Callable[
[
torch.Tensor,
Literal[
'acquisition_time_stamp', 'physiology_time_stamp_1', 'physiology_time_stamp_2', 'physiology_time_stamp_3'
],
],
torch.Tensor,
]


def convert_time_stamp_siemens(
timestamp: torch.Tensor,
_: str,
) -> torch.Tensor:
"""Convert Siemens time stamp to seconds."""
return timestamp.double() * 2.5e-3


@dataclass(slots=True)
class AcqIdx(MoveDataMixin):
"""Acquisition index for each readout."""
Expand Down Expand Up @@ -83,52 +103,59 @@ class AcqIdx(MoveDataMixin):


@dataclass(slots=True)
class AcqInfo(MoveDataMixin):
"""Acquisition information for each readout."""

idx: AcqIdx
"""Indices describing acquisitions (i.e. readouts)."""

acquisition_time_stamp: torch.Tensor
"""Clock time stamp. Not in s but in vendor-specific time units (e.g. 2.5ms for Siemens)"""
class UserValues(MoveDataMixin):
"""User Values used in AcqInfo."""

float1: torch.Tensor
float2: torch.Tensor
float3: torch.Tensor
float4: torch.Tensor
float5: torch.Tensor
float6: torch.Tensor
float7: torch.Tensor
float8: torch.Tensor
int1: torch.Tensor
int2: torch.Tensor
int3: torch.Tensor
int4: torch.Tensor
int5: torch.Tensor
int6: torch.Tensor
int7: torch.Tensor
int8: torch.Tensor

active_channels: torch.Tensor
"""Number of active receiver coil elements."""

available_channels: torch.Tensor
"""Number of available receiver coil elements."""
@dataclass(slots=True)
class PhysiologyTimestamps:
"""Time stamps relative to physiological triggering, e.g. ECG. Not in s but in vendor-specific time units."""

center_sample: torch.Tensor
"""Index of the readout sample corresponding to k-space center (zero indexed)."""
timestamp1: torch.Tensor
timestamp2: torch.Tensor
timestamp3: torch.Tensor

channel_mask: torch.Tensor
"""Bit mask indicating active coils (64*16 = 1024 bits)."""

discard_post: torch.Tensor
"""Number of readout samples to be discarded at the end (e.g. if the ADC is active during gradient events)."""
@dataclass(slots=True)
class AcqInfo(MoveDataMixin):
"""Acquisition information for each readout."""

discard_pre: torch.Tensor
"""Number of readout samples to be discarded at the beginning (e.g. if the ADC is active during gradient events)"""
idx: AcqIdx
"""Indices describing acquisitions (i.e. readouts)."""

encoding_space_ref: torch.Tensor
"""Indexed reference to the encoding spaces enumerated in the MRD (xml) header."""
acquisition_time_stamp: torch.Tensor
"""Clock time stamp. Usually in seconds (Siemens: seconds since midnight)"""

flags: torch.Tensor
"""A bit mask of common attributes applicable to individual acquisition readouts."""

measurement_uid: torch.Tensor
"""Unique ID corresponding to the readout."""

number_of_samples: torch.Tensor
"""Number of sample points per readout (readouts may have different number of sample points)."""

orientation: Rotation
"""Rotation describing the orientation of the readout, phase and slice encoding direction."""

patient_table_position: SpatialDimension[torch.Tensor]
"""Offset position of the patient table, in LPS coordinates [m]."""

physiology_time_stamp: torch.Tensor
physiology_time_stamps: PhysiologyTimestamps
"""Time stamps relative to physiological triggering, e.g. ECG. Not in s but in vendor-specific time units"""

position: SpatialDimension[torch.Tensor]
Expand All @@ -140,26 +167,48 @@ class AcqInfo(MoveDataMixin):
scan_counter: torch.Tensor
"""Zero-indexed incrementing counter for readouts."""

trajectory_dimensions: torch.Tensor # =3. We only support 3D Trajectories: kz always exists.
"""Dimensionality of the k-space trajectory vector."""

user_float: torch.Tensor
"""User-defined float parameters."""

user_int: torch.Tensor
"""User-defined int parameters."""
user: UserValues
"""User defined float or int values"""

version: torch.Tensor
"""Major version number."""
@overload
@classmethod
def from_ismrmrd_acquisitions(
cls,
acquisitions: Sequence[ismrmrd.acquisition.Acquisition],
*,
additional_fields: None,
convert_time_stamp: _convert_time_stamp_type = convert_time_stamp_siemens,
) -> Self: ...

@overload
@classmethod
def from_ismrmrd_acquisitions(
cls,
acquisitions: Sequence[ismrmrd.acquisition.Acquisition],
*,
additional_fields: Sequence[str],
convert_time_stamp: _convert_time_stamp_type = convert_time_stamp_siemens,
) -> tuple[Self, tuple[torch.Tensor, ...]]: ...

@classmethod
def from_ismrmrd_acquisitions(cls, acquisitions: Sequence[ismrmrd.Acquisition]) -> Self:
def from_ismrmrd_acquisitions(
cls,
acquisitions: Sequence[ismrmrd.acquisition.Acquisition],
*,
additional_fields: Sequence[str] | None = None,
convert_time_stamp: _convert_time_stamp_type = convert_time_stamp_siemens,
) -> Self | tuple[Self, tuple[torch.Tensor, ...]]:
"""Read the header of a list of acquisition and store information.
Parameters
----------
acquisitions:
acquisitions
list of ismrmrd acquisistions to read from. Needs at least one acquisition.
additional_fields
if supplied, additional information from fields with these names will be extracted from the
ismrmrd acquisitions and returned as tensors.
convert_time_stamp
function used to convert the raw time stamps to seconds.
"""
# Idea: create array of structs, then a struct of arrays,
# convert it into tensors to store in our dataclass.
Expand All @@ -169,9 +218,9 @@ def from_ismrmrd_acquisitions(cls, acquisitions: Sequence[ismrmrd.Acquisition])
raise ValueError('Acquisition list must not be empty.')

# Creating the dtype first and casting to bytes
# is a workaround for a bug in cpython > 3.12 causing a warning
# is np.array(AcquisitionHeader) is called directly.
# also, this needs to check the dtyoe only once.
# is a workaround for a bug in cpython causing a warning
# if np.array(AcquisitionHeader) is called directly.
# also, this needs to check the dtype only once.
acquisition_head_dtype = np.dtype(ismrmrd.AcquisitionHeader)
headers = np.frombuffer(
np.array([memoryview(a._head).cast('B') for a in acquisitions]),
Expand Down Expand Up @@ -228,33 +277,51 @@ def spatialdimension_2d(data: np.ndarray) -> SpatialDimension[torch.Tensor]:
user6=tensor(idx['user'][:, 6]),
user7=tensor(idx['user'][:, 7]),
)

user = UserValues(
tensor_2d(headers['user_float'][:, 0]),
tensor_2d(headers['user_float'][:, 1]),
tensor_2d(headers['user_float'][:, 2]),
tensor_2d(headers['user_float'][:, 3]),
tensor_2d(headers['user_float'][:, 4]),
tensor_2d(headers['user_float'][:, 5]),
tensor_2d(headers['user_float'][:, 6]),
tensor_2d(headers['user_float'][:, 7]),
tensor_2d(headers['user_int'][:, 0]),
tensor_2d(headers['user_int'][:, 1]),
tensor_2d(headers['user_int'][:, 2]),
tensor_2d(headers['user_int'][:, 3]),
tensor_2d(headers['user_int'][:, 4]),
tensor_2d(headers['user_int'][:, 5]),
tensor_2d(headers['user_int'][:, 6]),
tensor_2d(headers['user_int'][:, 7]),
)
physiology_time_stamps = PhysiologyTimestamps(
convert_time_stamp(tensor_2d(headers['physiology_time_stamp'][:, 0]), 'physiology_time_stamp_1'),
convert_time_stamp(tensor_2d(headers['physiology_time_stamp'][:, 1]), 'physiology_time_stamp_2'),
convert_time_stamp(tensor_2d(headers['physiology_time_stamp'][:, 2]), 'physiology_time_stamp_3'),
)
acq_info = cls(
idx=acq_idx,
acquisition_time_stamp=tensor_2d(headers['acquisition_time_stamp']),
active_channels=tensor_2d(headers['active_channels']),
available_channels=tensor_2d(headers['available_channels']),
center_sample=tensor_2d(headers['center_sample']),
channel_mask=tensor_2d(headers['channel_mask']),
discard_post=tensor_2d(headers['discard_post']),
discard_pre=tensor_2d(headers['discard_pre']),
encoding_space_ref=tensor_2d(headers['encoding_space_ref']),
acquisition_time_stamp=convert_time_stamp(
tensor_2d(headers['acquisition_time_stamp']), 'acquisition_time_stamp'
),
flags=tensor_2d(headers['flags']),
measurement_uid=tensor_2d(headers['measurement_uid']),
number_of_samples=tensor_2d(headers['number_of_samples']),
orientation=Rotation.from_directions(
spatialdimension_2d(headers['slice_dir']),
spatialdimension_2d(headers['phase_dir']),
spatialdimension_2d(headers['read_dir']),
),
patient_table_position=spatialdimension_2d(headers['patient_table_position']).apply_(mm_to_m),
physiology_time_stamp=tensor_2d(headers['physiology_time_stamp']),
position=spatialdimension_2d(headers['position']).apply_(mm_to_m),
sample_time_us=tensor_2d(headers['sample_time_us']),
scan_counter=tensor_2d(headers['scan_counter']),
trajectory_dimensions=tensor_2d(headers['trajectory_dimensions']).fill_(3), # see above
user_float=tensor_2d(headers['user_float']),
user_int=tensor_2d(headers['user_int']),
version=tensor_2d(headers['version']),
user=user,
physiology_time_stamps=physiology_time_stamps,
)
return acq_info

if additional_fields is None:
return acq_info
else:
additional_values = tuple(tensor_2d(headers[field]) for field in additional_fields)
return acq_info, additional_values
38 changes: 30 additions & 8 deletions src/mrpro/data/KData.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from mrpro.data.acq_filters import has_n_coils, is_image_acquisition
from mrpro.data.AcqInfo import AcqInfo, rearrange_acq_info_fields
from mrpro.data.EncodingLimits import Limits
from mrpro.data.enums import AcqFlags
from mrpro.data.KHeader import KHeader
from mrpro.data.KTrajectory import KTrajectory
from mrpro.data.KTrajectoryRawShape import KTrajectoryRawShape
Expand Down Expand Up @@ -136,17 +137,20 @@ def from_file(

kdata = torch.stack([torch.as_tensor(acq.data, dtype=torch.complex64) for acq in acquisitions])

acqinfo = AcqInfo.from_ismrmrd_acquisitions(acquisitions)
acq_info, (k0_center, n_k0_tensor, discard_pre, discard_post) = AcqInfo.from_ismrmrd_acquisitions(
acquisitions,
additional_fields=('center_sample', 'number_of_samples', 'discard_pre', 'discard_post'),
)

if len(torch.unique(acqinfo.idx.user5)) > 1:
if len(torch.unique(acq_info.idx.user5)) > 1:
warnings.warn(
'The Siemens to ismrmrd converter currently (ab)uses '
'the user 5 indices for storing the kspace center line number.\n'
'User 5 indices will be ignored',
stacklevel=1,
)

if len(torch.unique(acqinfo.idx.user6)) > 1:
if len(torch.unique(acq_info.idx.user6)) > 1:
warnings.warn(
'The Siemens to ismrmrd converter currently (ab)uses '
'the user 6 indices for storing the kspace center partition number.\n'
Expand All @@ -157,7 +161,7 @@ def from_file(
# Raises ValueError if required fields are missing in the header
kheader = KHeader.from_ismrmrd(
ismrmrd_header,
acqinfo,
acq_info,
defaults={
'datetime': modification_time, # use the modification time of the dataset as fallback
'trajectory': ktrajectory,
Expand All @@ -171,9 +175,9 @@ def from_file(
# (number_of_samples, center_sample) of (100, 20) (e.g. partial Fourier in the negative k0 direction) and
# (100, 80) (e.g. partial Fourier in the positive k0 direction) then this should lead to encoding limits of
# [min=0, max=159, center=80]
max_center_sample = int(torch.max(kheader.acq_info.center_sample))
max_pos_k0_extend = int(torch.max(kheader.acq_info.number_of_samples - kheader.acq_info.center_sample))
kheader.encoding_limits.k0 = Limits(0, max_center_sample + max_pos_k0_extend - 1, max_center_sample)
max_center_sample = int(torch.max(k0_center))
max_positive_k0_extend = int(torch.max(n_k0_tensor - k0_center))
kheader.encoding_limits.k0 = Limits(0, max_center_sample + max_positive_k0_extend - 1, max_center_sample)

# Sort and reshape the kdata and the acquisistion info according to the indices.
# within "other", the aquisistions are sorted in the order determined by KDIM_SORT_LABELS.
Expand Down Expand Up @@ -232,13 +236,31 @@ def from_file(
else field
)
kdata = rearrange(kdata[sort_idx], '(other k2 k1) coils k0 -> other coils k2 k1 k0', k1=n_k1, k2=n_k2)
k0_center = rearrange(k0_center[sort_idx], '(other k2 k1) ... -> other k2 k1 ...', k1=n_k1, k2=n_k2)

# Calculate trajectory and check if it matches the kdata shape
match ktrajectory:
case KTrajectoryIsmrmrd():
ktrajectory_final = ktrajectory(acquisitions).sort_and_reshape(sort_idx, n_k2, n_k1)
case KTrajectoryCalculator():
ktrajectory_or_rawshape = ktrajectory(kheader)
reversed_readout_mask = (kheader.acq_info.flags[..., 0] & AcqFlags.ACQ_IS_REVERSE.value).bool()
n_k0_unique = torch.unique(n_k0_tensor)
if len(n_k0_unique) > 1:
raise ValueError(
'Trajectory can only be calculated for constant number of readout samples.\n'
f'Got unique values {list(n_k0_unique)}'
)
ktrajectory_or_rawshape = ktrajectory(
n_k0=int(n_k0_unique[0]),
k0_center=k0_center,
k1_idx=kheader.acq_info.idx.k1,
k1_center=kheader.encoding_limits.k1.center,
k2_idx=kheader.acq_info.idx.k2,
k2_center=kheader.encoding_limits.k2.center,
reversed_readout_mask=reversed_readout_mask,
encoding_matrix=kheader.encoding_matrix,
)

if isinstance(ktrajectory_or_rawshape, KTrajectoryRawShape):
ktrajectory_final = ktrajectory_or_rawshape.sort_and_reshape(sort_idx, n_k2, n_k1)
else:
Expand Down
6 changes: 0 additions & 6 deletions src/mrpro/data/_kdata/KDataRemoveOsMixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,5 @@ def crop_readout(data_to_crop: torch.Tensor) -> torch.Tensor:

# Adapt header parameters
header = deepcopy(self.header)
header.acq_info.center_sample -= start_cropped_readout
header.acq_info.number_of_samples[:] = cropped_data.shape[-1]
header.encoding_matrix.x = cropped_data.shape[-1]

header.acq_info.discard_post = (header.acq_info.discard_post * x_ratio).to(torch.int32)
header.acq_info.discard_pre = (header.acq_info.discard_pre * x_ratio).to(torch.int32)

return type(self)(header, cropped_data, cropped_traj)
Loading

0 comments on commit 3c54fae

Please sign in to comment.