Skip to content

Commit

Permalink
Merge branch 'main' into notebooks-in-pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
lrlunin authored Nov 19, 2024
2 parents beddf13 + 8d24ebb commit 279a578
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 6 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ repos:
- id: typos

- repo: https://github.com/fzimmermann89/check_all
rev: v1.0
rev: v1.1
hooks:
- id: check-init-all
args: [--double-quotes]
args: [--double-quotes, --fix]
exclude: ^tests/

- repo: https://github.com/pre-commit/mirrors-mypy
Expand Down
8 changes: 5 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,14 @@ skip-magic-trailing-comma = false

[tool.typos.default]
locale = "en-us"
check-filename = false

[tool.typos.default.extend-words]
Reson = "Reson" # required for Proc. Intl. Soc. Mag. Reson. Med.
Reson = "Reson" # required for Proc. Intl. Soc. Mag. Reson. Med.
iy = "iy"
daa = 'daa' # required for wavelet operator
gaus = 'gaus' # required for wavelet operator
daa = "daa" # required for wavelet operator
gaus = "gaus" # required for wavelet operator
arange = "arange" # torch.arange

[tool.typos.files]
extend-exclude = [
Expand Down
4 changes: 3 additions & 1 deletion src/mrpro/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import mrpro.utils.slice_profiles
import mrpro.utils.typing
import mrpro.utils.unit_conversion
from mrpro.utils.fill_range import fill_range_
from mrpro.utils.smap import smap
from mrpro.utils.remove_repeat import remove_repeat
from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop
Expand All @@ -10,6 +11,7 @@

__all__ = [
"broadcast_right",
"fill_range_",
"reduce_view",
"remove_repeat",
"slice_profiles",
Expand All @@ -20,4 +22,4 @@
"unsqueeze_left",
"unsqueeze_right",
"zero_pad_or_crop"
]
]
24 changes: 24 additions & 0 deletions src/mrpro/utils/fill_range.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Fill tensor in-place along a specified dimension with increasing integers."""

import torch


def fill_range_(tensor: torch.Tensor, dim: int) -> None:
"""
Fill tensor in-place along a specified dimension with increasing integers.
Parameters
----------
tensor
The tensor to be modified in-place.
dim
The dimension along which to fill with increasing values.
"""
if not -tensor.ndim <= dim < tensor.ndim:
raise IndexError(f'Dimension {dim} is out of range for tensor with {tensor.ndim} dimensions.')

dim = dim % tensor.ndim
shape = [s if d == dim else 1 for d, s in enumerate(tensor.shape)]
values = torch.arange(tensor.size(dim), device=tensor.device).reshape(shape)
tensor[:] = values.expand_as(tensor)
23 changes: 23 additions & 0 deletions tests/utils/test_fill_range.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Tests for fill_range_"""

import pytest
import torch
from mrpro.utils import fill_range_


@pytest.mark.parametrize('dtype', [torch.float32, torch.int64], ids=['float32', 'int64'])
def test_fill_range(dtype):
"""Test functionality of fill_range."""
tensor = torch.zeros(3, 4, dtype=dtype)
fill_range_(tensor, dim=1)
expected = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]], dtype=tensor.dtype)
torch.testing.assert_close(tensor, expected)


def test_fill_range_dim_out_of_range():
"""Test fill_range_ with a dimension out of range."""
tensor = torch.zeros(3, 4)
with pytest.raises(IndexError, match='Dimension 2 is out of range'):
fill_range_(tensor, dim=2)
with pytest.raises(IndexError, match='Dimension -3 is out of range'):
fill_range_(tensor, dim=-3)

0 comments on commit 279a578

Please sign in to comment.