
Commit

Merge branch 'main' into get_diff_coils
fzimmermann89 authored Nov 12, 2024
2 parents 41be9ae + 84b983c commit b0b5326
Showing 19 changed files with 155 additions and 76 deletions.
10 changes: 9 additions & 1 deletion .github/workflows/deployment.yml
@@ -94,7 +94,15 @@ jobs:
run: |
VERSION=${{ needs.build-testpypi-package.outputs.version }}
SUFFIX=${{ needs.build-testpypi-package.outputs.suffix }}
python -m pip install mrpro==$VERSION$SUFFIX --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/
for i in {1..3}; do
if python -m pip install mrpro==$VERSION$SUFFIX --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/; then
echo "Package installed successfully."
break
else
echo "Attempt $i failed. Retrying in 10 seconds..."
sleep 10
fi
done
build-pypi-package:
name: Build Package for PyPI
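The workflow now retries the TestPyPI install, since a freshly uploaded release can take a short while to appear on the index. A minimal Python sketch of the same retry-with-delay idea; the helper name and signature are illustrative, not part of this commit:

import subprocess
import sys
import time


def pip_install_with_retry(spec: str, attempts: int = 3, delay: float = 10.0) -> None:
    """Install `spec` from TestPyPI, retrying if the index has not served the new release yet."""
    for attempt in range(1, attempts + 1):
        result = subprocess.run(
            [
                sys.executable, '-m', 'pip', 'install', spec,
                '--index-url', 'https://test.pypi.org/simple/',
                '--extra-index-url', 'https://pypi.org/simple/',
            ],
            check=False,
        )
        if result.returncode == 0:
            print('Package installed successfully.')
            return
        print(f'Attempt {attempt} failed. Retrying in {delay:.0f} seconds...')
        time.sleep(delay)
    raise RuntimeError(f'Could not install {spec} after {attempts} attempts')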
17 changes: 8 additions & 9 deletions .pre-commit-config.yaml
@@ -6,7 +6,6 @@ repos:
rev: v5.0.0
hooks:
- id: check-added-large-files
- id: check-docstring-first
- id: check-merge-conflict
- id: check-yaml
- id: check-toml
@@ -54,11 +53,11 @@ repos:
- "--extra-index-url=https://pypi.python.org/simple"

ci:
autofix_commit_msg: |
[pre-commit] auto fixes from pre-commit hooks
autofix_prs: false
autoupdate_branch: ''
autoupdate_commit_msg: '[pre-commit] pre-commit autoupdate'
autoupdate_schedule: monthly
skip: [mypy]
submodules: false
autofix_commit_msg: |
[pre-commit] auto fixes from pre-commit hooks
autofix_prs: false
autoupdate_branch: ""
autoupdate_commit_msg: "[pre-commit] pre-commit autoupdate"
autoupdate_schedule: monthly
skip: [mypy]
submodules: false
8 changes: 4 additions & 4 deletions examples/direct_reconstruction.ipynb
@@ -37,10 +37,10 @@
"\n",
"import requests\n",
"\n",
"with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n",
" response = requests.get(zenodo_url + fname, timeout=30)\n",
" data_file.write(response.content)\n",
" data_file.flush()"
"data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')\n",
"response = requests.get(zenodo_url + fname, timeout=30)\n",
"data_file.write(response.content)\n",
"data_file.flush()"
]
},
{
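The example scripts and notebooks now keep the downloaded temporary file open instead of wrapping it in a `with` block, so the file handle and its name stay usable in the later reconstruction cells; the matching SIM115 ignore is added to examples/ruff.toml further down. A small self-contained sketch of the pattern, using one of the Zenodo records referenced in these examples:

import tempfile

import requests

# One of the Zenodo records used by the examples; the other example files are downloaded the same way.
zenodo_url = 'https://zenodo.org/records/10854057/files/'
fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5'

# delete=False keeps the file on disk, and skipping the `with` block keeps the handle open
# for the rest of the script instead of closing it at the end of the block.
data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')
response = requests.get(zenodo_url + fname, timeout=30)
response.raise_for_status()  # fail early on a bad download
data_file.write(response.content)
data_file.flush()

print(data_file.name)  # path handed to the data readers later in the example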
8 changes: 4 additions & 4 deletions examples/direct_reconstruction.py
@@ -11,10 +11,10 @@

import requests

with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:
response = requests.get(zenodo_url + fname, timeout=30)
data_file.write(response.content)
data_file.flush()
data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')
response = requests.get(zenodo_url + fname, timeout=30)
data_file.write(response.content)
data_file.flush()

# %% [markdown]
# ### Image reconstruction
8 changes: 4 additions & 4 deletions examples/iterative_sense_reconstruction.ipynb
@@ -37,10 +37,10 @@
"\n",
"import requests\n",
"\n",
"with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n",
" response = requests.get(zenodo_url + fname, timeout=30)\n",
" data_file.write(response.content)\n",
" data_file.flush()"
"data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')\n",
"response = requests.get(zenodo_url + fname, timeout=30)\n",
"data_file.write(response.content)\n",
"data_file.flush()"
]
},
{
8 changes: 4 additions & 4 deletions examples/iterative_sense_reconstruction.py
@@ -11,10 +11,10 @@

import requests

with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:
response = requests.get(zenodo_url + fname, timeout=30)
data_file.write(response.content)
data_file.flush()
data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')
response = requests.get(zenodo_url + fname, timeout=30)
data_file.write(response.content)
data_file.flush()

# %% [markdown]
# ### Image reconstruction
22 changes: 10 additions & 12 deletions examples/pulseq_2d_radial_golden_angle.ipynb
@@ -33,14 +33,13 @@
"cell_type": "code",
"execution_count": null,
"id": "d16f41f1",
"metadata": {
"lines_to_next_cell": 2
},
"metadata": {},
"outputs": [],
"source": [
"# define zenodo records URL and create a temporary directory and h5-file\n",
"zenodo_url = 'https://zenodo.org/records/10854057/files/'\n",
"fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5'"
"fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5'\n",
"data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')"
]
},
{
@@ -51,10 +50,9 @@
"outputs": [],
"source": [
"# Download raw data using requests\n",
"with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n",
" response = requests.get(zenodo_url + fname, timeout=30)\n",
" data_file.write(response.content)\n",
" data_file.flush()"
"response = requests.get(zenodo_url + fname, timeout=30)\n",
"data_file.write(response.content)\n",
"data_file.flush()"
]
},
{
@@ -127,10 +125,10 @@
"# download the sequence file from zenodo\n",
"zenodo_url = 'https://zenodo.org/records/10868061/files/'\n",
"seq_fname = 'pulseq_radial_2D_402spokes_golden_angle.seq'\n",
"with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq') as seq_file:\n",
" response = requests.get(zenodo_url + seq_fname, timeout=30)\n",
" seq_file.write(response.content)\n",
" seq_file.flush()"
"seq_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq')\n",
"response = requests.get(zenodo_url + seq_fname, timeout=30)\n",
"seq_file.write(response.content)\n",
"seq_file.flush()"
]
},
{
17 changes: 8 additions & 9 deletions examples/pulseq_2d_radial_golden_angle.py
@@ -19,14 +19,13 @@
# define zenodo records URL and create a temporary directory and h5-file
zenodo_url = 'https://zenodo.org/records/10854057/files/'
fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5'

data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')

# %%
# Download raw data using requests
with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:
response = requests.get(zenodo_url + fname, timeout=30)
data_file.write(response.content)
data_file.flush()
response = requests.get(zenodo_url + fname, timeout=30)
data_file.write(response.content)
data_file.flush()

# %% [markdown]
# ### Image reconstruction using KTrajectoryIsmrmrd
@@ -63,10 +62,10 @@
# download the sequence file from zenodo
zenodo_url = 'https://zenodo.org/records/10868061/files/'
seq_fname = 'pulseq_radial_2D_402spokes_golden_angle.seq'
with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq') as seq_file:
response = requests.get(zenodo_url + seq_fname, timeout=30)
seq_file.write(response.content)
seq_file.flush()
seq_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq')
response = requests.get(zenodo_url + seq_fname, timeout=30)
seq_file.write(response.content)
seq_file.flush()

# %%
# Read raw data and calculate trajectory using KTrajectoryPulseq
8 changes: 4 additions & 4 deletions examples/regularized_iterative_sense_reconstruction.ipynb
@@ -37,10 +37,10 @@
"\n",
"import requests\n",
"\n",
"with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n",
" response = requests.get(zenodo_url + fname, timeout=30)\n",
" data_file.write(response.content)\n",
" data_file.flush()"
"data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')\n",
"response = requests.get(zenodo_url + fname, timeout=30)\n",
"data_file.write(response.content)\n",
"data_file.flush()"
]
},
{
8 changes: 4 additions & 4 deletions examples/regularized_iterative_sense_reconstruction.py
@@ -11,10 +11,10 @@

import requests

with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:
response = requests.get(zenodo_url + fname, timeout=30)
data_file.write(response.content)
data_file.flush()
data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')
response = requests.get(zenodo_url + fname, timeout=30)
data_file.write(response.content)
data_file.flush()

# %% [markdown]
# ### Image reconstruction
1 change: 1 addition & 0 deletions examples/ruff.toml
@@ -5,4 +5,5 @@ lint.extend-ignore = [
"T20", #print
"E402", #module-import-not-at-top-of-file
"S101", #assert
"SIM115", #context manager for opening files
]
2 changes: 1 addition & 1 deletion src/mrpro/VERSION
@@ -1 +1 @@
0.241029
0.241112
4 changes: 1 addition & 3 deletions src/mrpro/data/MoveDataMixin.py
@@ -252,11 +252,9 @@ def apply_(
----------
function
The function to apply to all fields. None is interpreted as a no-op.
memo
A dictionary to keep track of objects that the function has already been applied to,
A dictionary to keep track of objects that the function has already been applied to,
to avoid multiple applications. This is useful if the object has a circular reference.
recurse
If True, the function will be applied to all children that are MoveDataMixin instances.
"""
5 changes: 4 additions & 1 deletion src/mrpro/data/_kdata/KDataSelectMixin.py
@@ -7,6 +7,7 @@
from typing_extensions import Self

from mrpro.data._kdata.KDataProtocol import _KDataProtocol
from mrpro.data.Rotation import Rotation


class KDataSelectMixin(_KDataProtocol):
@@ -50,7 +51,9 @@ def select_other_subset(
other_idx = torch.cat([torch.where(idx == label_idx[:, 0, 0])[0] for idx in subset_idx], dim=0)

# Adapt header
kheader.acq_info.apply_(lambda field: field[other_idx, ...] if isinstance(field, torch.Tensor) else field)
kheader.acq_info.apply_(
lambda field: field[other_idx, ...] if isinstance(field, torch.Tensor | Rotation) else field
)

# Select data
kdat = self.data[other_idx, ...]
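The widened isinstance(field, torch.Tensor | Rotation) check lets apply_ subset acquisition-info fields that are Rotation objects as well as plain tensors. A dependency-free sketch of that dispatch idea; the helper below is illustrative and only checks for tensors, since mrpro's Rotation indexes with the same syntax:

import torch


def subset_field(field: object, other_idx: torch.Tensor) -> object:
    """Subset indexable fields along the first ('other') dimension; leave everything else untouched."""
    if isinstance(field, torch.Tensor):  # the commit also accepts Rotation here, which indexes like a tensor
        return field[other_idx, ...]
    return field


idx = torch.tensor([0, 2])
print(subset_field(torch.arange(12).reshape(4, 3), idx))  # rows 0 and 2
print(subset_field('not a tensor', idx))                  # returned unchanged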
17 changes: 8 additions & 9 deletions src/mrpro/data/_kdata/KDataSplitMixin.py
@@ -1,6 +1,6 @@
"""Mixin class to split KData into other subsets."""

from typing import Literal, TypeVar
from typing import Literal, TypeVar, cast

import torch
from einops import rearrange, repeat
@@ -10,10 +10,8 @@
from mrpro.data.AcqInfo import rearrange_acq_info_fields
from mrpro.data.EncodingLimits import Limits
from mrpro.data.Rotation import Rotation
from mrpro.data.SpatialDimension import SpatialDimension

T = TypeVar('T', torch.Tensor, Rotation, SpatialDimension)

RotationOrTensor = TypeVar('RotationOrTensor', bound=torch.Tensor | Rotation)

class KDataSplitMixin(_KDataProtocol):
"""Split KData into other subsets."""
@@ -59,8 +57,9 @@ def _split_k2_or_k1_into_other(
def split_data_traj(dat_traj: torch.Tensor) -> torch.Tensor:
return dat_traj[:, :, :, split_idx, :]

def split_acq_info(acq_info: T) -> T:
return acq_info[:, :, split_idx, ...]
def split_acq_info(acq_info: RotationOrTensor) -> RotationOrTensor:
# cast due to https://github.com/python/mypy/issues/10817
return cast(RotationOrTensor, acq_info[:, :, split_idx, ...])

# Rearrange other_split and k1 dimension
rearrange_pattern_data = 'other coils k2 other_split k1 k0->(other other_split) coils k2 k1 k0'
@@ -72,8 +71,8 @@ def split_acq_info(acq_info: T) -> T:
def split_data_traj(dat_traj: torch.Tensor) -> torch.Tensor:
return dat_traj[:, :, split_idx, :, :]

def split_acq_info(acq_info: T) -> T:
return acq_info[:, split_idx, ...]
def split_acq_info(acq_info: RotationOrTensor) -> RotationOrTensor:
return cast(RotationOrTensor, acq_info[:, split_idx, ...])

# Rearrange other_split and k1 dimension
rearrange_pattern_data = 'other coils other_split k2 k1 k0->(other other_split) coils k2 k1 k0'
@@ -101,7 +100,7 @@ def split_acq_info(acq_info: T) -> T:
# Update shape of acquisition info index
kheader.acq_info.apply_(
lambda field: rearrange_acq_info_fields(split_acq_info(field), rearrange_pattern_acq_info)
if isinstance(field, T.__constraints__)
if isinstance(field, Rotation | torch.Tensor)
else field
)

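Replacing the constrained TypeVar T with the bound TypeVar RotationOrTensor keeps the return type of split_acq_info tied to its argument type, and the explicit cast works around the mypy limitation referenced in the comment (python/mypy#10817). A standalone sketch of the pattern, with the bound simplified to torch.Tensor so it runs without mrpro:

from typing import TypeVar, cast

import torch

# In the commit the bound is `torch.Tensor | Rotation`; a plain tensor bound keeps this sketch self-contained.
RotationOrTensor = TypeVar('RotationOrTensor', bound=torch.Tensor)


def split_acq_info(acq_info: RotationOrTensor, split_idx: torch.Tensor) -> RotationOrTensor:
    # cast is needed because mypy does not narrow the indexed result back to the TypeVar.
    return cast(RotationOrTensor, acq_info[:, split_idx, ...])


info = torch.arange(24.0).reshape(2, 4, 3)
print(split_acq_info(info, torch.tensor([0, 2])).shape)  # torch.Size([2, 2, 3])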
5 changes: 4 additions & 1 deletion src/mrpro/utils/__init__.py
@@ -1,13 +1,16 @@
import mrpro.utils.slice_profiles
import mrpro.utils.typing
import mrpro.utils.unit_conversion
from mrpro.utils.smap import smap
from mrpro.utils.remove_repeat import remove_repeat
from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop
from mrpro.utils.split_idx import split_idx
from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right
from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view
import mrpro.utils.unit_conversion

__all__ = [
"broadcast_right",
"reduce_view",
"remove_repeat",
"slice_profiles",
"smap",
32 changes: 32 additions & 0 deletions src/mrpro/utils/reshape.py
@@ -1,5 +1,7 @@
"""Tensor reshaping utilities."""

from collections.abc import Sequence

import torch


@@ -67,3 +69,33 @@ def broadcast_right(*x: torch.Tensor) -> tuple[torch.Tensor, ...]:
max_dim = max(el.ndim for el in x)
unsqueezed = torch.broadcast_tensors(*(unsqueeze_right(el, max_dim - el.ndim) for el in x))
return unsqueezed


def reduce_view(x: torch.Tensor, dim: int | Sequence[int] | None = None) -> torch.Tensor:
    """Reduce expanded dimensions in a view to singletons.

    Reduce either all or specific dimensions to a singleton if it
    points to the same memory address.
    This undoes expand.

    Parameters
    ----------
    x
        input tensor
    dim
        only reduce expanded dimensions in the specified dimensions.
        If None, reduce all expanded dimensions.
    """
    if dim is None:
        dim_: Sequence[int] = range(x.ndim)
    elif isinstance(dim, Sequence):
        dim_ = [d % x.ndim for d in dim]
    else:
        dim_ = [dim % x.ndim]

    stride = x.stride()
    newsize = [
        1 if stride == 0 and d in dim_ else oldsize
        for d, (oldsize, stride) in enumerate(zip(x.size(), stride, strict=True))
    ]
    return torch.as_strided(x, newsize, stride)
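A short usage sketch of the new reduce_view helper, which the __init__.py change above exports from mrpro.utils; the shapes in the comments are what the implementation produces for an expanded view:

import torch

from mrpro.utils import reduce_view

x = torch.randn(3, 1, 5).expand(3, 4, 5)  # dim 1 is expanded: stride 0, no extra memory
print(reduce_view(x).shape)               # torch.Size([3, 1, 5]) -- expanded dim collapsed back to 1
print(reduce_view(x, dim=0).shape)        # torch.Size([3, 4, 5]) -- dim 0 is not expanded, so unchanged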