Skip to content
This repository has been archived by the owner on Jan 12, 2024. It is now read-only.

Commit

Permalink
add per sample option for affine transforms
Browse files Browse the repository at this point in the history
mibaumgartner committed May 16, 2020
1 parent d810446 commit 5831199
Showing 2 changed files with 29 additions and 4 deletions.
27 changes: 23 additions & 4 deletions rising/transforms/affine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Sequence, Union, Iterable, Any, Optional, Tuple
from typing import Sequence, Union, Any, Optional, Tuple

import torch

@@ -37,6 +37,7 @@ def __init__(self,
padding_mode: str = 'zeros',
align_corners: bool = False,
reverse_order: bool = False,
per_sample: bool = True,
**kwargs):
"""
Args:
@@ -66,6 +67,8 @@ def __init__(self,
transformation to conform to the pytorch convention:
transformation params order [W,H(,D)] and
batch order [(D,)H,W]
per_sample: sample different values for each element in the batch.
The transform is still applied to the whole batch in a single pass.
**kwargs: additional keyword arguments passed to the
affine transform
"""
@@ -80,6 +83,7 @@ def __init__(self,
self.padding_mode = padding_mode
self.align_corners = align_corners
self.reverse_order = reverse_order
self.per_sample = per_sample

def assemble_matrix(self, **data) -> torch.Tensor:
"""
@@ -314,6 +318,7 @@ def __init__(self,
padding_mode: str = 'zeros',
align_corners: bool = False,
reverse_order: bool = False,
per_sample: bool = True,
**kwargs,
):
"""
@@ -364,6 +369,8 @@ def __init__(self,
transformation to conform to the pytorch convention:
transformation params order [W,H(,D)] and
batch order [(D,)H,W]
per_sample: sample different values for each element in the batch.
The transform is still applied to the whole batch in a single pass.
**kwargs: additional keyword arguments passed to the
affine transform
"""
@@ -373,6 +380,7 @@ def __init__(self,
padding_mode=padding_mode,
align_corners=align_corners,
reverse_order=reverse_order,
per_sample=per_sample,
**kwargs)
self.register_sampler('scale', scale)
self.register_sampler('rotation', rotation)
@@ -406,12 +414,23 @@ def assemble_matrix(self, **data) -> torch.Tensor:
device=device, dtype=dtype, image_transform=self.image_transform)
return self.matrix

def sample_for_batch(self, name: str, batchsize: int) -> Optional[
        Union[Any, Sequence[Any]]]:
    """
    Sample transformation parameter(s) for a whole batch.

    When ``self.per_sample`` is True, the registered sampler behind
    ``name`` is queried once per batch element so every sample gets its
    own value; otherwise the single already-sampled value is returned
    and shared across the batch.

    Args:
        name: name of the registered parameter (e.g. ``'scale'``,
            ``'rotation'``) looked up via ``getattr``
        batchsize: number of elements in the batch

    Returns:
        Optional[Union[Any, Sequence[Any]]]: a list with ``batchsize``
        independently sampled values if sampling per sample; the single
        shared value (or ``None`` if the parameter is unset) otherwise
    """
    # NOTE: each getattr call is assumed to draw a fresh value from the
    # registered sampler — that is what makes per-sample values differ.
    elem = getattr(self, name)
    if elem is not None and self.per_sample:
        # first draw is reused; the remaining batchsize - 1 are fresh
        return [elem] + [getattr(self, name) for _ in range(batchsize - 1)]
    return elem  # either a single scalar value or None


class Rotate(BaseAffine):
6 changes: 6 additions & 0 deletions tests/transforms/test_affine.py
Original file line number Diff line number Diff line change
@@ -121,6 +121,9 @@ def test_affine_subtypes(self):
Translate(10, adjust_size=True, unit="pixel"),
Translate(10, adjust_size=False, unit="pixel"),
Translate([5, 10], adjust_size=False, unit="pixel"),
Scale(5, adjust_size=False, per_sample=False),
Rotate([90], adjust_size=False, degree=True, per_sample=False),
Translate(10, adjust_size=False, unit="pixel", per_sample=False),
]

expected_sizes = [
@@ -137,6 +140,9 @@ def test_affine_subtypes(self):
(25, 30),
(25, 30),
(25, 30),
(25, 30),
(25, 30),
(25, 30),
]

for trafo, expected_size in zip(trafos, expected_sizes):

0 comments on commit 5831199

Please sign in to comment.