Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IHeader #591

Draft
wants to merge 11 commits into
base: gh/fzimmermann89/36/head
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 147 additions & 48 deletions src/mrpro/data/IHeader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,45 +2,127 @@

import dataclasses
from collections.abc import Sequence
from dataclasses import dataclass
from dataclasses import dataclass, field

import numpy as np
import torch
from einops import repeat
from pydicom.dataset import Dataset
from pydicom.tag import Tag, TagType
from typing_extensions import Self

from mrpro.data.KHeader import KHeader
from mrpro.data.MoveDataMixin import MoveDataMixin
from mrpro.data.Rotation import Rotation
from mrpro.data.SpatialDimension import SpatialDimension
from mrpro.utils.remove_repeat import remove_repeat
from mrpro.utils.summarize_tensorvalues import summarize_tensorvalues
from mrpro.utils.unit_conversion import deg_to_rad, mm_to_m, ms_to_s

from .AcqInfo import PhysiologyTimestamps

MISC_TAGS = {'TimeAfterStart': 0x00191016}


def _int_factory() -> torch.Tensor:
return torch.zeros(1, 1, dtype=torch.int64)


@dataclass(slots=True)
class ImageIdx(MoveDataMixin):
"""Acquisition index for each readout."""

average: torch.Tensor = field(default_factory=_int_factory)
"""Signal average."""

slice: torch.Tensor = field(default_factory=_int_factory)
"""Slice number (multi-slice 2D)."""

contrast: torch.Tensor = field(default_factory=_int_factory)
"""Echo number in multi-echo."""

phase: torch.Tensor = field(default_factory=_int_factory)
"""Cardiac phase."""

repetition: torch.Tensor = field(default_factory=_int_factory)
"""Counter in repeated/dynamic acquisitions."""

set: torch.Tensor = field(default_factory=_int_factory)
"""Sets of different preparation, e.g. flow encoding, diffusion weighting."""

user0: torch.Tensor = field(default_factory=_int_factory)
"""User index 0."""

user1: torch.Tensor = field(default_factory=_int_factory)
"""User index 1."""

user2: torch.Tensor = field(default_factory=_int_factory)
"""User index 2."""

user3: torch.Tensor = field(default_factory=_int_factory)
"""User index 3."""

user4: torch.Tensor = field(default_factory=_int_factory)
"""User index 4."""

user5: torch.Tensor = field(default_factory=_int_factory)
"""User index 5."""

user6: torch.Tensor = field(default_factory=_int_factory)
"""User index 6."""

user7: torch.Tensor = field(default_factory=_int_factory)
"""User index 7."""


@dataclass(slots=True)
class IHeader(MoveDataMixin):
"""MR image data header."""

# ToDo: decide which attributes to store in the header
fov: SpatialDimension[float]
"""Field of view [m]."""

te: torch.Tensor | None
te: torch.Tensor | None = None
"""Echo time [s]."""

ti: torch.Tensor | None
ti: torch.Tensor | None = None
"""Inversion time [s]."""

fa: torch.Tensor | None
fa: torch.Tensor | None = None
"""Flip angle [rad]."""

tr: torch.Tensor | None
tr: torch.Tensor | None = None
"""Repetition time [s]."""

misc: dict = dataclasses.field(default_factory=dict)
_misc: dict = dataclasses.field(default_factory=dict)
"""Dictionary with miscellaneous parameters."""

position: SpatialDimension[torch.Tensor] = field(
default_factory=lambda: SpatialDimension(
torch.zeros(1, 1, 1, 1, 1),
torch.zeros(1, 1, 1, 1, 1),
torch.zeros(1, 1, 1, 1, 1),
)
)
"""Center of the excited volume"""

orientation: Rotation = field(default_factory=lambda: Rotation.identity((1, 1, 1, 1, 1)))
"""Orientation of the image"""

patient_table_position: SpatialDimension[torch.Tensor] = field(
default_factory=lambda: SpatialDimension(
torch.zeros(1, 1, 1, 1, 1),
torch.zeros(1, 1, 1, 1, 1),
torch.zeros(1, 1, 1, 1, 1),
)
)
"""Offset position of the patient table"""

acquisition_time_stamp: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 1, 1, 1, 1))

physiology_time_stamps: PhysiologyTimestamps = field(default_factory=PhysiologyTimestamps)

ImageIdx: ImageIdx = field(default_factory=ImageIdx)

@classmethod
def from_kheader(cls, kheader: KHeader) -> Self:
"""Create IHeader object from KHeader object.
Expand Down Expand Up @@ -74,53 +156,70 @@ def get_item(dataset: Dataset, name: TagType):
else:
raise ValueError(f'Item {name} found {len(found_item)} times.')

def get_items_from_all_dicoms(name: TagType):
"""Get list of items for all dataset objects in the list."""
def get_items_from_dicom_datasets(name: TagType) -> list:
"""Get list of items for all datasets in dicom_datasets."""
return [get_item(ds, name) for ds in dicom_datasets]

def get_float_items_from_all_dicoms(name: TagType):
"""Convert items to float."""
items = get_items_from_all_dicoms(name)
return [float(val) if val is not None else None for val in items]

def make_unique_tensor(values: Sequence[float]) -> torch.Tensor | None:
"""If all the values are the same only return one."""
if any(val is None for val in values):
def get_float_items_from_dicom_datasets(name: TagType) -> list[float]:
"""Get float items from all dataset in dicom_datasets."""
items = []
for item in get_items_from_dicom_datasets(name):
try:
items.append(float(item))
except (TypeError, ValueError):
# None or invalid value
items.append(float('nan'))
return items

def as_5d_tensor(values: Sequence[float]) -> torch.Tensor:
"""Convert a list of values to a 5d tensor."""
tensor = torch.as_tensor(values)
tensor = repeat(tensor, 'values-> values 1 1 1 1')
tensor = remove_repeat(tensor, 1e-12)
return tensor

def all_nan_to_none(tensor: torch.Tensor) -> torch.Tensor | None:
"""If all values are nan, return None."""
if torch.isnan(tensor).all():
return None
elif len(np.unique(values)) == 1:
return torch.as_tensor([values[0]])
else:
return torch.as_tensor(values)

# Conversion functions for units
def ms_to_s(ms: torch.Tensor | None) -> torch.Tensor | None:
return None if ms is None else ms / 1000

def deg_to_rad(deg: torch.Tensor | None) -> torch.Tensor | None:
return None if deg is None else torch.deg2rad(deg)

fa = deg_to_rad(make_unique_tensor(get_float_items_from_all_dicoms('FlipAngle')))
ti = ms_to_s(make_unique_tensor(get_float_items_from_all_dicoms('InversionTime')))
tr = ms_to_s(make_unique_tensor(get_float_items_from_all_dicoms('RepetitionTime')))

# get echo time(s). Some scanners use 'EchoTime', some use 'EffectiveEchoTime'
te_list = get_float_items_from_all_dicoms('EchoTime')
if all(val is None for val in te_list): # check if all entries are None
te_list = get_float_items_from_all_dicoms('EffectiveEchoTime')
te = ms_to_s(make_unique_tensor(te_list))

fov_x_mm = get_float_items_from_all_dicoms('Rows')[0] * float(get_items_from_all_dicoms('PixelSpacing')[0][0])
fov_y_mm = get_float_items_from_all_dicoms('Columns')[0] * float(
get_items_from_all_dicoms('PixelSpacing')[0][1],
)
fov_z_mm = get_float_items_from_all_dicoms('SliceThickness')[0]
fov = SpatialDimension(fov_x_mm, fov_y_mm, fov_z_mm) / 1000 # convert to m
return tensor

fa = all_nan_to_none(deg_to_rad(as_5d_tensor(get_float_items_from_dicom_datasets('FlipAngle'))))
ti = all_nan_to_none(ms_to_s(as_5d_tensor(get_float_items_from_dicom_datasets('InversionTime'))))
tr = all_nan_to_none(ms_to_s(as_5d_tensor(get_float_items_from_dicom_datasets('RepetitionTime'))))

te_list = get_float_items_from_dicom_datasets('EchoTime')
if all(val is None for val in te_list):
# if all 'EchoTime' entries are None, try 'EffectiveEchoTime',
# which is used by some scanners
te_list = get_float_items_from_dicom_datasets('EffectiveEchoTime')
te = all_nan_to_none(ms_to_s(as_5d_tensor(te_list)))

try:
fov_x = mm_to_m(
get_float_items_from_dicom_datasets('Rows')[0]
* float(get_items_from_dicom_datasets('PixelSpacing')[0][0])
)
except (TypeError, ValueError):
fov_x = float('nan')
try:
fov_y = mm_to_m(
get_float_items_from_dicom_datasets('Columns')[0]
* float(get_items_from_dicom_datasets('PixelSpacing')[0][1])
)
except (TypeError, ValueError):
fov_y = float('nan')
try:
fov_z = mm_to_m(get_float_items_from_dicom_datasets('SliceThickness')[0])
except (TypeError, ValueError):
fov_z = float('nan')
fov = SpatialDimension(fov_z, fov_y, fov_x)

# Get misc parameters
misc = {}
for name in MISC_TAGS:
misc[name] = make_unique_tensor(get_float_items_from_all_dicoms(MISC_TAGS[name]))
return cls(fov=fov, te=te, ti=ti, fa=fa, tr=tr, misc=misc)
misc[name] = as_5d_tensor(get_float_items_from_dicom_datasets(MISC_TAGS[name]))
return cls(fov=fov, te=te, ti=ti, fa=fa, tr=tr, _misc=misc)

def __repr__(self):
"""Representation method for IHeader class."""
Expand Down
Loading