-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into notebooks-in-pre-commit
- Loading branch information
Showing
5 changed files
with
57 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
"""Fill tensor in-place along a specified dimension with increasing integers.""" | ||
|
||
import torch | ||
|
||
|
||
def fill_range_(tensor: torch.Tensor, dim: int) -> None: | ||
""" | ||
Fill tensor in-place along a specified dimension with increasing integers. | ||
Parameters | ||
---------- | ||
tensor | ||
The tensor to be modified in-place. | ||
dim | ||
The dimension along which to fill with increasing values. | ||
""" | ||
if not -tensor.ndim <= dim < tensor.ndim: | ||
raise IndexError(f'Dimension {dim} is out of range for tensor with {tensor.ndim} dimensions.') | ||
|
||
dim = dim % tensor.ndim | ||
shape = [s if d == dim else 1 for d, s in enumerate(tensor.shape)] | ||
values = torch.arange(tensor.size(dim), device=tensor.device).reshape(shape) | ||
tensor[:] = values.expand_as(tensor) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
"""Tests for fill_range_""" | ||
|
||
import pytest | ||
import torch | ||
from mrpro.utils import fill_range_ | ||
|
||
|
||
@pytest.mark.parametrize('dtype', [torch.float32, torch.int64], ids=['float32', 'int64']) | ||
def test_fill_range(dtype): | ||
"""Test functionality of fill_range.""" | ||
tensor = torch.zeros(3, 4, dtype=dtype) | ||
fill_range_(tensor, dim=1) | ||
expected = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]], dtype=tensor.dtype) | ||
torch.testing.assert_close(tensor, expected) | ||
|
||
|
||
def test_fill_range_dim_out_of_range(): | ||
"""Test fill_range_ with a dimension out of range.""" | ||
tensor = torch.zeros(3, 4) | ||
with pytest.raises(IndexError, match='Dimension 2 is out of range'): | ||
fill_range_(tensor, dim=2) | ||
with pytest.raises(IndexError, match='Dimension -3 is out of range'): | ||
fill_range_(tensor, dim=-3) |