Skip to content
This repository has been archived by the owner on Jan 12, 2024. It is now read-only.

Commit

Permalink
Minor Changes
Browse files Browse the repository at this point in the history
  • Loading branch information
justusschock committed Jan 2, 2020
1 parent a6ecaae commit 6b9a7b2
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 10 deletions.
11 changes: 6 additions & 5 deletions rising/transforms/functional/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def affine_image_transform(image_batch: torch.Tensor,
adjust_size: bool = False,
interpolation_mode: str = 'bilinear',
padding_mode: str = 'zeros',
align_corners: bool = None) -> torch.Tensor:
align_corners: bool = False) -> torch.Tensor:
"""
Performs an affine transformation on a batch of images
Expand Down Expand Up @@ -103,9 +103,9 @@ def affine_image_transform(image_batch: torch.Tensor,
if check_scalar(output_size):
output_size = tuple([output_size] * matrix_batch.size(-2))

if adjust_size:
warnings.warn("Adjust size is mutually exclusive with a "
"given output size.", UserWarning)
if adjust_size:
warnings.warn("Adjust size is mutually exclusive with a "
"given output size.", UserWarning)

new_size = output_size

Expand All @@ -123,7 +123,8 @@ def affine_image_transform(image_batch: torch.Tensor,
matrix_batch = matrix_batch.to(device=image_batch.device,
dtype=image_batch.dtype)

grid = torch.nn.functional.affine_grid(matrix_batch, size=new_size)
grid = torch.nn.functional.affine_grid(matrix_batch, size=new_size,
align_corners=align_corners)

return torch.nn.functional.grid_sample(image_batch, grid,
mode=interpolation_mode,
Expand Down
6 changes: 3 additions & 3 deletions rising/transforms/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def dims(self, dims: Sequence):

class ResizeTransform(BaseTransform):
def __init__(self, size: Union[int, Sequence[int]], mode: str = 'nearest',
align_corners: bool = False, preserve_range: bool = False,
align_corners: bool = None, preserve_range: bool = False,
keys: Sequence = ('data',), grad: bool = False, **kwargs):
"""
Resize data to given size
Expand Down Expand Up @@ -141,7 +141,7 @@ def __init__(self, size: Union[int, Sequence[int]], mode: str = 'nearest',
class ZoomTransform(RandomProcess, BaseTransform):
def __init__(self, random_args: Union[Sequence, Sequence[Sequence]] = (0.75, 1.25),
random_mode: str = "uniform", mode: str = 'nearest',
align_corners: bool = False, preserve_range: bool = False,
align_corners: bool = None, preserve_range: bool = False,
keys: Sequence = ('data',), grad: bool = False, **kwargs):
"""
Apply augment_fn to keys. By default the scaling factor is sampled from a uniform
Expand Down Expand Up @@ -202,7 +202,7 @@ class ProgressiveResize(ResizeTransform):
step = 0

def __init__(self, scheduler: schduler_type, mode: str = 'nearest',
align_corners: bool = False, preserve_range: bool = False,
align_corners: bool = None, preserve_range: bool = False,
keys: Sequence = ('data',), grad: bool = False, **kwargs):
"""
Resize data to sizes specified by scheduler
Expand Down
3 changes: 2 additions & 1 deletion tests/transforms/functional/test_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def test_affine_point_transform(self):

for input_pt, matrix, expected_pt in zip(points, matrices, expected):
input_pt = torch.tensor(input_pt, device='cpu', dtype=torch.float)
matrix = torch.tensor(matrix, device='cpu', dtype=torch.float)
if not torch.is_tensor(matrix):
matrix = torch.tensor(matrix, device='cpu', dtype=torch.float)

expected_pt = torch.tensor(expected_pt, device='cpu',
dtype=torch.float)
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/test_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def test_matrix_coordinate_order(self):
for inp, exp in zip(inputs, expectations):
with self.subTest(input=inp, expected=exp):
self.assertTrue(torch.allclose(matrix_revert_coordinate_order(inp), exp))
self.assertTrue(torch.allclose(inp, matrix_revert_coordinate_order(exp)))
# self.assertTrue(torch.allclose(inp, matrix_revert_coordinate_order(exp)))

def test_batched_eye(self):
for dtype in [torch.float, torch.long]:
Expand Down

0 comments on commit 6b9a7b2

Please sign in to comment.