-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
38722bf
commit 8d24ebb
Showing
5 changed files
with
57 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
"""Fill tensor in-place along a specified dimension with increasing integers.""" | ||
|
||
import torch | ||
|
||
|
||
def fill_range_(tensor: torch.Tensor, dim: int) -> None: | ||
""" | ||
Fill tensor in-place along a specified dimension with increasing integers. | ||
Parameters | ||
---------- | ||
tensor | ||
The tensor to be modified in-place. | ||
dim | ||
The dimension along which to fill with increasing values. | ||
""" | ||
if not -tensor.ndim <= dim < tensor.ndim: | ||
raise IndexError(f'Dimension {dim} is out of range for tensor with {tensor.ndim} dimensions.') | ||
|
||
dim = dim % tensor.ndim | ||
shape = [s if d == dim else 1 for d, s in enumerate(tensor.shape)] | ||
values = torch.arange(tensor.size(dim), device=tensor.device).reshape(shape) | ||
tensor[:] = values.expand_as(tensor) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
"""Tests for fill_range_""" | ||
|
||
import pytest | ||
import torch | ||
from mrpro.utils import fill_range_ | ||
|
||
|
||
@pytest.mark.parametrize('dtype', [torch.float32, torch.int64], ids=['float32', 'int64']) | ||
def test_fill_range(dtype): | ||
"""Test functionality of fill_range.""" | ||
tensor = torch.zeros(3, 4, dtype=dtype) | ||
fill_range_(tensor, dim=1) | ||
expected = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]], dtype=tensor.dtype) | ||
torch.testing.assert_close(tensor, expected) | ||
|
||
|
||
def test_fill_range_dim_out_of_range(): | ||
"""Test fill_range_ with a dimension out of range.""" | ||
tensor = torch.zeros(3, 4) | ||
with pytest.raises(IndexError, match='Dimension 2 is out of range'): | ||
fill_range_(tensor, dim=2) | ||
with pytest.raises(IndexError, match='Dimension -3 is out of range'): | ||
fill_range_(tensor, dim=-3) |
8d24ebb
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Coverage Report
8d24ebb
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Coverage Report
8d24ebb
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Coverage Report