From 1520bddd9f3676707d8c2b35eb075beb59aac7ce Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Sun, 16 Feb 2020 18:38:31 +0100 Subject: [PATCH 01/20] remove inverting of axis, start refactoring, add transformation example noebook --- notebooks/transformations.ipynb | 270 +++++++++++++++++++++ rising/transforms/__init__.py | 1 + rising/transforms/affine.py | 105 +++++++- rising/transforms/functional/affine.py | 54 ++--- rising/utils/affine.py | 55 ++--- tests/transforms/functional/test_affine.py | 51 ++-- tests/transforms/test_affine.py | 29 ++- tests/utils/test_affine.py | 53 ++-- 8 files changed, 491 insertions(+), 127 deletions(-) create mode 100644 notebooks/transformations.ipynb diff --git a/notebooks/transformations.ipynb b/notebooks/transformations.ipynb new file mode 100644 index 00000000..09cec93f --- /dev/null +++ b/notebooks/transformations.ipynb @@ -0,0 +1,270 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-02-15T09:57:50.791332Z", + "start_time": "2020-02-15T09:57:46.068701Z" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "!pip install napari\n", + "!pip install SimpleITK" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-02-16T16:54:54.920459Z", + "start_time": "2020-02-16T16:54:54.669509Z" + } + }, + "outputs": [], + "source": [ + "%reload_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline\n", + "TEST=0 # TODO: replace this with environment flag" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-02-15T11:55:06.130717Z", + "start_time": "2020-02-15T11:54:22.119553Z" + } + }, + "outputs": [], + "source": [ + "from io import BytesIO\n", + "from zipfile import ZipFile\n", + "from urllib.request import urlopen\n", + "\n", + "resp = urlopen(\"http://www.fmrib.ox.ac.uk/primers/intro_primer/ExBox3/ExBox3.zip\")\n", + "zipfile = ZipFile(BytesIO(resp.read()))\n", + "\n", + "img_file = zipfile.extract(\"ExBox3/T1_brain.nii.gz\")\n", + "mask_file = zipfile.extract(\"ExBox3/T1_brain_seg.nii.gz\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-02-16T16:55:01.394975Z", + "start_time": "2020-02-16T16:55:00.893340Z" + } + }, + "outputs": [], + "source": [ + "import SimpleITK as sitk\n", + "import numpy as np\n", + "\n", + "# load image and mask\n", + "img_file = \"./ExBox3/T1_brain.nii.gz\"\n", + "mask_file = \"./ExBox3/T1_brain_seg.nii.gz\"\n", + "img = sitk.GetArrayFromImage(sitk.ReadImage(img_file))\n", + "img = img.astype(np.float32)\n", + "mask = mask = sitk.GetArrayFromImage(sitk.ReadImage(mask_file))\n", + "mask = mask.astype(np.float32)\n", + "\n", + "assert mask.shape == img.shape\n", + "print(f\"Image shape {img.shape}\")\n", + "print(f\"Image shape {mask.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-02-16T16:55:04.255613Z", + "start_time": "2020-02-16T16:55:03.213336Z" + } + }, + "outputs": [], + "source": [ + "%gui qt\n", + "import napari\n", + "def view_batch(batch):\n", + " if not TEST:\n", + " viewer = napari.view_image(batch[\"data\"].cpu().numpy(), name=\"data\")\n", + " viewer.add_image(batch[\"mask\"].cpu().numpy(), name=\"mask\", opacity=0.2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": 
"2020-02-16T16:55:54.082493Z", + "start_time": "2020-02-16T16:55:54.019599Z" + } + }, + "outputs": [], + "source": [ + "import torch\n", + "from rising.transforms import *\n", + "\n", + "batch = {\n", + " \"data\": torch.from_numpy(img).float()[None],\n", + " \"mask\": torch.from_numpy(mask).long()[None],\n", + "}\n", + "\n", + "def apply_transform(trafo, batch):\n", + " transformed = trafo(**batch)\n", + " print(f\"Transformed data shape: {transformed['data'].shape}\")\n", + " print(f\"Transformed mask shape: {transformed['mask'].shape}\")\n", + " print(f\"Transformed data min: {transformed['data'].min()}\")\n", + " print(f\"Transformed data max: {transformed['data'].max()}\")\n", + " print(f\"Transformed data mean: {transformed['data'].mean()}\")\n", + " return transformed" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-02-16T16:55:06.109008Z", + "start_time": "2020-02-16T16:55:06.069336Z" + } + }, + "outputs": [], + "source": [ + "print(f\"Transformed data shape: {batch['data'].shape}\")\n", + "print(f\"Transformed mask shape: {batch['mask'].shape}\")\n", + "print(f\"Transformed data min: {batch['data'].min()}\")\n", + "print(f\"Transformed data max: {batch['data'].max()}\")\n", + "print(f\"Transformed data mean: {batch['data'].mean()}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-02-16T16:55:57.391117Z", + "start_time": "2020-02-16T16:55:55.675294Z" + } + }, + "outputs": [], + "source": [ + "trafo = Scale(1.5, adjust_size=False)\n", + "transformed = apply_transform(trafo, batch)\n", + "view_batch(transformed)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-02-16T17:03:58.535489Z", + "start_time": "2020-02-16T17:03:57.964843Z" + } + }, + "outputs": [], + "source": [ + "trafo = Rotate(45, degree=True, adjust_size=False)\n", + "transformed = apply_transform(trafo, batch)\n", + "view_batch(transformed)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-02-16T16:00:26.032367Z", + "start_time": "2020-02-16T16:00:25.466391Z" + } + }, + "outputs": [], + "source": [ + "trafo = Translate(0.1, adjust_size=False)\n", + "transformed = apply_transform(trafo, batch)\n", + "view_batch(transformed)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.5" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + }, + "varInspector": { + "cols": { + "lenName": 16, + "lenType": 16, + "lenVar": 40 + }, + "kernels_config": { + "python": { + "delete_cmd_postfix": "", + "delete_cmd_prefix": "del ", + "library": "var_list.py", + "varRefreshCmd": "print(var_dic_list())" + }, + "r": { + "delete_cmd_postfix": ") ", + "delete_cmd_prefix": "rm(", + "library": "var_list.r", 
+ "varRefreshCmd": "cat(var_dic_list()) " + } + }, + "types_to_exclude": [ + "module", + "function", + "builtin_function_or_method", + "instance", + "_Feature" + ], + "window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/rising/transforms/__init__.py b/rising/transforms/__init__.py index 99b95a19..026fe35f 100644 --- a/rising/transforms/__init__.py +++ b/rising/transforms/__init__.py @@ -8,3 +8,4 @@ from rising.transforms.spatial import * from rising.transforms.utility import * from rising.transforms.tensor import * +from rising.transforms.affine import * diff --git a/rising/transforms/affine.py b/rising/transforms/affine.py index bf764e30..239d3873 100644 --- a/rising/transforms/affine.py +++ b/rising/transforms/affine.py @@ -1,16 +1,18 @@ from rising.transforms.abstract import BaseTransform from rising.transforms.functional.affine import affine_image_transform +from rising.utils import check_scalar from rising.utils.affine import AffineParamType, \ assemble_matrix_if_necessary, matrix_to_homogeneous, matrix_to_cartesian import torch -from typing import Sequence, Union +from typing import Sequence, Union, Iterable __all__ = [ 'Affine', 'StackedAffine', 'Rotate', 'Scale', - 'Translate' + 'Translate', + 'Resize' ] @@ -572,16 +574,105 @@ def assemble_matrix(self, **data) -> torch.Tensor: """ whole_trafo = None - for trafo in self.transforms: matrix = matrix_to_homogeneous(trafo.assemble_matrix(**data)) - if whole_trafo is None: whole_trafo = matrix else: whole_trafo = torch.bmm(whole_trafo, matrix) - return matrix_to_cartesian(whole_trafo) -# TODO: Add transforms around image center -# TODO: Add Resize Transform + +class Resize(Scale): + def __init__(self, + size: Union[int, Iterable], + keys: Sequence = ('data',), + grad: bool = False, + interpolation_mode: str = 'bilinear', + padding_mode: str = 'zeros', + align_corners: bool = False, + **kwargs): + """ + Class Performing a Resizing Affine Transformation on a given + sample dict. + The transformation will be applied to all the dict-entries specified + in :attr:`keys`. + + Parameters + ---------- + size : int, Iterable + the target size. If int, this will be repeated for all the + dimensions + keys: Sequence + keys which should be augmented + grad: bool + enable gradient computation inside transformation + interpolation_mode : str + interpolation mode to calculate output values + 'bilinear' | 'nearest'. Default: 'bilinear' + padding_mode : + padding mode for outside grid values + 'zeros' | 'border' | 'reflection'. Default: 'zeros' + align_corners : Geometrically, we consider the pixels of the input as + squares rather than points. If set to True, the extrema (-1 and 1) + are considered as referring to the center points of the input’s + corner pixels. If set to False, they are instead considered as + referring to the corner points of the input’s corner pixels, + making the sampling more resolution agnostic. 
+ **kwargs : + additional keyword arguments passed to the affine transform + + Note + ---- + The offsets for shifting back and to origin are calculated on the + entry matching the first item iin :attr:`keys` for each batch + + Note + ---- + The target size must be specified in x, y (,z) order and will be + converted to (D,) H, W order internally + + """ + super().__init__(output_size=size, + scale=None, + keys=keys, + grad=grad, + adjust_size=False, + interpolation_mode=interpolation_mode, + padding_mode=padding_mode, + align_corners=align_corners, + **kwargs) + + def assemble_matrix(self, **data) -> torch.Tensor: + """ + Handles the matrix assembly and calculates the scale factors for + resizing + + Parameters + ---------- + **data : + the data to be transformed. Will be used to determine batchsize, + dimensionality, dtype and device + + Returns + ------- + torch.Tensor + the (batched) transformation matrix + + """ + curr_img_size = data[self.keys[0]].shape[2:] + + was_scalar = check_scalar(self.output_size) + + if was_scalar: + self.output_size = [self.output_size] * len(curr_img_size) + + self.scale = [self.output_size[i] / curr_img_size[-i] + for i in range(len(curr_img_size))] + + matrix = super().assemble_matrix(**data) + + if was_scalar: + self.output_size = self.output_size[0] + + return matrix diff --git a/rising/transforms/functional/affine.py b/rising/transforms/functional/affine.py index db98401a..062626e3 100644 --- a/rising/transforms/functional/affine.py +++ b/rising/transforms/functional/affine.py @@ -1,6 +1,6 @@ import torch from rising.utils.affine import points_to_cartesian, matrix_to_homogeneous, \ - points_to_homogeneous, matrix_revert_coordinate_order + points_to_homogeneous, unit_box from rising.utils.checktype import check_scalar import warnings @@ -30,12 +30,8 @@ def affine_point_transform(point_batch: torch.Tensor, """ point_batch = points_to_homogeneous(point_batch) matrix_batch = matrix_to_homogeneous(matrix_batch) - - matrix_batch = matrix_revert_coordinate_order(matrix_batch) - transformed_points = torch.bmm(point_batch, matrix_batch.permute(0, 2, 1)) - return points_to_cartesian(transformed_points) @@ -120,8 +116,7 @@ def affine_image_transform(image_batch: torch.Tensor, missing_dims = len(image_batch.shape) - len(image_size) new_size = (*image_batch.shape[:missing_dims], *new_size) - matrix_batch = matrix_batch.to(device=image_batch.device, - dtype=image_batch.dtype) + matrix_batch = matrix_batch.to(image_batch) grid = torch.nn.functional.affine_grid(matrix_batch, size=new_size, align_corners=align_corners) @@ -132,7 +127,8 @@ def affine_image_transform(image_batch: torch.Tensor, align_corners=align_corners) -def _check_new_img_size(curr_img_size, matrix: torch.Tensor) -> torch.Tensor: +def _check_new_img_size(curr_img_size, matrix: torch.Tensor, + zero_border: bool = False) -> torch.Tensor: """ Calculates the image size so that the whole image content fits the image. 
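A minimal usage sketch of the Resize transform defined above, using the dict-based call convention from the notebook cells at the top of this patch and the re-export added to rising/transforms/__init__.py; the batch shape and target size below are made up for illustration:

import torch
from rising.transforms import Resize

batch = {"data": torch.rand(2, 1, 25, 30)}   # N x C x H x W

# resample every entry listed in `keys` (default: only 'data') to a fixed spatial size
trafo = Resize((50, 90))
out = trafo(**batch)
print(out["data"].shape)   # expected: torch.Size([2, 1, 50, 90])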
The resulting size will be the maximum size of the batch, so that the @@ -145,49 +141,27 @@ def _check_new_img_size(curr_img_size, matrix: torch.Tensor) -> torch.Tensor: all image dimensions matrix : torch.Tensor a batch of affine matrices with shape N x NDIM x NDIM + 1 + zero_border : bool + whether or not to have a fixed image border at zero Returns ------- torch.Tensor the new image size - """ - n_dim = matrix.size(-1) - 1 - if check_scalar(curr_img_size): curr_img_size = [curr_img_size] * n_dim - - curr_img_size = [tmp - 1 for tmp in curr_img_size] - - if n_dim == 2: - possible_points = torch.tensor([[0., 0.], [0., curr_img_size[1]], - [curr_img_size[0], 0], curr_img_size], - dtype=matrix.dtype, - device=matrix.device) - elif n_dim == 3: - possible_points = torch.tensor( - [ - [0., 0., 0.], - [0., 0., curr_img_size[2]], - [0., curr_img_size[1], 0], - [0., curr_img_size[1], curr_img_size[2]], - [curr_img_size[0], 0., 0.], - [curr_img_size[0], 0., curr_img_size[2]], - [curr_img_size[0], curr_img_size[1], 0.], - curr_img_size - ], device=matrix.device, dtype=matrix.dtype - ) - - else: - raise ValueError('Invalid number of dimensions! Expected One of ' - '{2, 3}, but got %s' % str(n_dim)) + possible_points = unit_box(n_dim, torch.tensor(curr_img_size)).to(matrix) transformed_edges = affine_point_transform( possible_points[None].expand( - matrix.size(0), - *[-1 for _ in possible_points.shape]).clone(), + matrix.size(0), *[-1 for _ in possible_points.shape]).clone(), matrix) - return (transformed_edges.max(1)[0] - - transformed_edges.min(1)[0]).max(0)[0] + 1 + if zero_border: + substr = 0 + else: + substr = transformed_edges.min(1)[0] + + return (transformed_edges.max(1)[0] - substr).max(0)[0] diff --git a/rising/utils/affine.py b/rising/utils/affine.py index ee21a9ab..9b6da49a 100644 --- a/rising/utils/affine.py +++ b/rising/utils/affine.py @@ -1,8 +1,11 @@ import torch -from rising.utils.checktype import check_scalar +import itertools + from math import pi from typing import Union, Sequence +from rising.utils.checktype import check_scalar + AffineParamType = Union[int, float, Sequence, torch.Tensor] @@ -106,28 +109,6 @@ def points_to_cartesian(batch: torch.Tensor) -> torch.Tensor: return batch[..., :-1] / batch[..., -1, None] -def matrix_revert_coordinate_order(batch: torch.Tensor) -> torch.Tensor: - """ - Reverts the coordinate order of a matrix (e.g. from xyz to zyx). 
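The geometry behind _check_new_img_size can be reproduced with plain torch, independent of the rising API: transform the corner points of the (scaled) unit box and take their extent. The 100 x 100 size and the 45 degree angle below are arbitrary example values.

import math
import torch

# corners of a 100 x 100 image, i.e. the unit box scaled by the image size
corners = torch.tensor([[0., 0.], [0., 100.], [100., 0.], [100., 100.]])

angle = math.radians(45)
rot = torch.tensor([[math.cos(angle), -math.sin(angle)],
                    [math.sin(angle),  math.cos(angle)]])

transformed = corners @ rot.t()
extent = transformed.max(dim=0)[0] - transformed.min(dim=0)[0]
print(extent)   # roughly tensor([141.42, 141.42]); the rotated content needs a larger canvas
# with zero_border=True only the maximum coordinate is kept, so content below zero is cut off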
- - Parameters - ---------- - batch : torch.Tensor - the batched transformation matrices; Should be of shape - BATCHSIZE x NDIM x NDIM - - Returns - ------- - torch.Tensor - the matrix performing the same transformation on vectors with a - reversed coordinate order - - """ - batch[:, :-1, :] = batch[:, :-1, :].flip(1).clone() - batch[:, :-1, :-1] = batch[:, :-1, :-1].flip(2).clone() - return batch - - def get_batched_eye(batchsize: int, ndim: int, device: Union[torch.device, str] = None, dtype: Union[torch.dtype, str] = None) -> torch.Tensor: @@ -379,7 +360,6 @@ def _format_rotation(rotation: AffineParamType, if check_scalar(rotation): rotation = torch.ones(batchsize, num_rot_params, device=device, dtype=dtype) * rotation - elif not torch.is_tensor(rotation): rotation = torch.tensor(rotation, device=device, dtype=dtype) @@ -404,8 +384,7 @@ def _format_rotation(rotation: AffineParamType, return rotation # bring it to default size of (batchsize, num_rot_params) elif rotation.size() == (batchsize,): - rotation = rotation.view(batchsize, 1).expand(-1, - num_rot_params).clone() + rotation = rotation.view(batchsize, 1).expand(-1, num_rot_params).clone() elif rotation.size() == (num_rot_params,): rotation = rotation.view(1, num_rot_params).expand(batchsize, -1).clone() @@ -424,7 +403,6 @@ def _format_rotation(rotation: AffineParamType, whole_rot_matrix[:, 1, 1] = cos[0].clone() whole_rot_matrix[:, 0, 1] = (-sin[0]).clone() whole_rot_matrix[:, 1, 0] = sin[0].clone() - else: whole_rot_matrix[:, 0, 0] = (cos[:, 0] * cos[:, 1] * cos[:, 2] - sin[:, 0] * sin[:, 2]).clone() @@ -618,3 +596,26 @@ def assemble_matrix_if_necessary(batchsize: int, ndim: int, "Got %s but expected %s" % ( str(tuple(matrix.shape)), str((batchsize, ndim, ndim + 1)))) + + +def unit_box(n: int, scale: torch.Tensor = None) -> torch.Tensor: + """ + Create a sclaed version of a unit box + + Parameters + ---------- + n: int + number of dimensions + scale: Tensor + scaling of each dimension + + Returns + ------- + Tensor + scaled unit box + """ + box = torch.tensor( + [list(i) for i in itertools.product([0, 1], repeat=n)]) + if scale is not None: + box = box.to(scale) * scale[None] + return box diff --git a/tests/transforms/functional/test_affine.py b/tests/transforms/functional/test_affine.py index b5877678..8cf519b4 100644 --- a/tests/transforms/functional/test_affine.py +++ b/tests/transforms/functional/test_affine.py @@ -2,13 +2,11 @@ import torch from rising.transforms.functional.affine import _check_new_img_size, \ affine_point_transform, affine_image_transform -from rising.utils.affine import parametrize_matrix, matrix_to_homogeneous, matrix_to_cartesian, \ - matrix_revert_coordinate_order +from rising.utils.affine import parametrize_matrix, matrix_to_homogeneous, matrix_to_cartesian from rising.utils.checktype import check_scalar class AffineTestCase(unittest.TestCase): - def test_check_image_size(self): images = [torch.rand(11, 2, 3, 4, 5), torch.rand(11, 2, 3, 4), torch.rand(11, 2, 3, 3)] @@ -51,28 +49,35 @@ def test_check_image_size(self): batchsize=1, ndim=ndim, dtype=torch.float)) edge_pts = torch.tensor(edge_pts, dtype=torch.float) - edge_pts[edge_pts > 1] = edge_pts[edge_pts > 1] - 1 img = img.to(torch.float) - - new_edges = torch.bmm(edge_pts.unsqueeze(0), - matrix_revert_coordinate_order(affine.clone()).permute(0, 2, 1)) - - img_size = (new_edges.max(dim=1)[0] - new_edges.min(dim=1)[0])[0] - - fn_result = _check_new_img_size(size, - matrix_to_cartesian( - affine.expand(img.size(0), -1, -1).clone())) - - 
self.assertTrue(torch.allclose(img_size[:-1] + 1, - fn_result)) - - with self.assertRaises(ValueError): - _check_new_img_size([2, 3, 4, 5], torch.rand(11, 2, 2, 3, 4, 5)) + new_edges = torch.bmm(edge_pts.unsqueeze(0), affine.clone().permute(0, 2, 1)) + + img_size_zero_border = new_edges.max(dim=1)[0][0] + img_size_non_zero_border = (new_edges.max(dim=1)[0] + - new_edges.min(dim=1)[0])[0] + + fn_result_zero_border = _check_new_img_size( + size, + matrix_to_cartesian( + affine.expand(img.size(0), -1, -1).clone()), + zero_border=True, + ) + fn_result_non_zero_border = _check_new_img_size( + size, + matrix_to_cartesian( + affine.expand(img.size(0), -1, -1).clone()), + zero_border=False, + ) + + self.assertTrue(torch.allclose(img_size_zero_border[:-1], + fn_result_zero_border)) + self.assertTrue(torch.allclose(img_size_non_zero_border[:-1], + fn_result_non_zero_border)) def test_affine_point_transform(self): points = [ [[[0, 1], [1, 0]]], - [[[0, 0, 1]]] + [[[1, 1, 1]]], ] matrices = [ torch.tensor([[[1., 0.], [0., 5.]]]), @@ -84,8 +89,8 @@ def test_affine_point_transform(self): device='cpu') ] expected = [ - [[0, 1], [5, 0]], - [[0, 1, 0]] + [[0, 5], [1, 0]], + [[-1, 1, 1]] ] for input_pt, matrix, expected_pt in zip(points, matrices, expected): @@ -111,7 +116,7 @@ def test_affine_image_trafo(self): image_batch = torch.zeros(10, 3, 25, 25, dtype=torch.float, device='cpu') - target_sizes = [(121, 97), image_batch.shape[2:], (50, 50), (50, 50), + target_sizes = [(100, 125), image_batch.shape[2:], (50, 50), (50, 50), (45, 50), (45, 50)] for output_size in [None, 50, (45, 50)]: diff --git a/tests/transforms/test_affine.py b/tests/transforms/test_affine.py index 79406c14..2b6c7d04 100644 --- a/tests/transforms/test_affine.py +++ b/tests/transforms/test_affine.py @@ -1,6 +1,7 @@ import unittest -from rising.transforms.affine import Affine, StackedAffine, Translate, Rotate, \ - Scale +from rising.transforms.affine import ( + Affine, StackedAffine, Translate, Scale, Rotate) +# TODO: add resize transform import torch from copy import deepcopy from rising.utils.affine import matrix_to_cartesian, matrix_to_homogeneous @@ -13,7 +14,7 @@ def test_affine(self): device='cpu') matrix = matrix.expand(image_batch.size(0), -1, -1).clone() - target_sizes = [(121, 97), image_batch.shape[2:], (50, 50), (50, 50), + target_sizes = [(100, 125), image_batch.shape[2:], (50, 50), (50, 50), (45, 50), (45, 50)] for output_size in [None, 50, (45, 50)]: @@ -78,17 +79,25 @@ def test_stacked_transformation_assembly(self): self.assertTrue(torch.allclose(matrix, target_matrix)) def test_affine_subtypes(self): + sample = {'data': torch.rand(1, 3, 25, 30)} - sample = {'data': torch.rand(10, 3, 25, 25)} trafos = [ - Scale(5), - Rotate(45), - Translate(10) + Scale([2, 3], adjust_size=True), + # Resize([50, 90]), + Rotate([90], adjust_size=True, degree=True), ] - for trafo in trafos: - with self.subTest(trafo=trafo): - self.assertIsInstance(trafo(**sample)['data'], torch.Tensor) + expected_sizes = [ + (50, 90), + # (50, 90), + (30, 25), + ] + + for trafo, expected_size in zip(trafos, expected_sizes): + with self.subTest(trafo=trafo, exp_size=expected_size): + result = trafo(**sample)['data'] + self.assertIsInstance(result, torch.Tensor) + self.assertTupleEqual(expected_size, result.shape[-2:]) if __name__ == '__main__': diff --git a/tests/utils/test_affine.py b/tests/utils/test_affine.py index 8867524c..f215a47e 100644 --- a/tests/utils/test_affine.py +++ b/tests/utils/test_affine.py @@ -1,8 +1,9 @@ import unittest from 
rising.utils.affine import points_to_homogeneous, matrix_to_homogeneous, \ - matrix_to_cartesian, points_to_cartesian, matrix_revert_coordinate_order, \ + matrix_to_cartesian, points_to_cartesian, \ get_batched_eye, _format_scale, _format_translation, deg_to_rad, \ - _format_rotation, parametrize_matrix, assemble_matrix_if_necessary + _format_rotation, parametrize_matrix, assemble_matrix_if_necessary, \ + unit_box import torch import math @@ -128,24 +129,6 @@ def test_matrix_to_cartesian(self): self.assertTrue(torch.allclose(matrix_to_cartesian(inp, keep_square=keep_square), exp)) keep_square = not keep_square - def test_matrix_coordinate_order(self): - inputs = [ - torch.tensor([[[1, 2, 3], - [4, 5, 6], - [7, 8, 9]]]) - ] - - expectations = [ - torch.tensor([[[5, 4, 6], - [2, 1, 3], - [7, 8, 9]]]) - ] - - for inp, exp in zip(inputs, expectations): - with self.subTest(input=inp, expected=exp): - self.assertTrue(torch.allclose(matrix_revert_coordinate_order(inp), exp)) - # self.assertTrue(torch.allclose(inp, matrix_revert_coordinate_order(exp))) - def test_batched_eye(self): for dtype in [torch.float, torch.long]: for ndim in range(10): @@ -363,6 +346,36 @@ def test_necessary_assembly(self): degree=False, dtype=torch.float, device='cpu', batchsize=1, ndim=2) + def test_unit_box_2d(self): + curr_img_size = torch.tensor([2, 3]) + box = torch.tensor([[0., 0.], [0., curr_img_size[1]], + [curr_img_size[0], 0], curr_img_size]) + created_box = unit_box(2, curr_img_size).to(box) + self.compare_points_unordered(box, created_box) + + def compare_points_unordered(self, points0: torch.Tensor, points1: torch.Tensor): + self.assertEqual(tuple(points0.shape), tuple(points1.shape)) + for point in points0: + comp = point[None] == points1 + comp = comp.sum(dim=1) == comp.shape[1] + self.assertTrue(comp.any()) + + def test_unit_box_3d(self): + curr_img_size = torch.tensor([2, 3, 4]) + box = torch.tensor( + [ + [0., 0., 0.], + [0., 0., curr_img_size[2]], + [0., curr_img_size[1], 0], + [0., curr_img_size[1], curr_img_size[2]], + [curr_img_size[0], 0., 0.], + [curr_img_size[0], 0., curr_img_size[2]], + [curr_img_size[0], curr_img_size[1], 0.], + curr_img_size + ]) + created_box = unit_box(3, curr_img_size).to(box) + self.compare_points_unordered(box, created_box) + if __name__ == '__main__': unittest.main() From 147eb9de05554ff0704b5ace33fde1d15c72309d Mon Sep 17 00:00:00 2001 From: Michael Baumgartner Date: Sun, 16 Feb 2020 17:39:38 +0000 Subject: [PATCH 02/20] autopep8 fix --- tests/utils/test_affine.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/utils/test_affine.py b/tests/utils/test_affine.py index f215a47e..e74cb870 100644 --- a/tests/utils/test_affine.py +++ b/tests/utils/test_affine.py @@ -349,7 +349,7 @@ def test_necessary_assembly(self): def test_unit_box_2d(self): curr_img_size = torch.tensor([2, 3]) box = torch.tensor([[0., 0.], [0., curr_img_size[1]], - [curr_img_size[0], 0], curr_img_size]) + [curr_img_size[0], 0], curr_img_size]) created_box = unit_box(2, curr_img_size).to(box) self.compare_points_unordered(box, created_box) @@ -363,16 +363,16 @@ def compare_points_unordered(self, points0: torch.Tensor, points1: torch.Tensor) def test_unit_box_3d(self): curr_img_size = torch.tensor([2, 3, 4]) box = torch.tensor( - [ - [0., 0., 0.], - [0., 0., curr_img_size[2]], - [0., curr_img_size[1], 0], - [0., curr_img_size[1], curr_img_size[2]], - [curr_img_size[0], 0., 0.], - [curr_img_size[0], 0., curr_img_size[2]], - [curr_img_size[0], curr_img_size[1], 
0.], - curr_img_size - ]) + [ + [0., 0., 0.], + [0., 0., curr_img_size[2]], + [0., curr_img_size[1], 0], + [0., curr_img_size[1], curr_img_size[2]], + [curr_img_size[0], 0., 0.], + [curr_img_size[0], 0., curr_img_size[2]], + [curr_img_size[0], curr_img_size[1], 0.], + curr_img_size + ]) created_box = unit_box(3, curr_img_size).to(box) self.compare_points_unordered(box, created_box) From 2cb2d2d0575a302db0a34684d1e1780acdca650a Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Tue, 18 Feb 2020 18:02:49 +0100 Subject: [PATCH 03/20] refactoring --- CHANGELOG.md | 3 + rising/transforms/affine.py | 3 +- rising/transforms/compose.py | 33 +++++++++- rising/transforms/functional/utility.py | 69 ++++++++++++++++++++- rising/transforms/utility.py | 50 ++++++++++++++- tests/transforms/functional/test_utility.py | 51 +++++++++++++++ tests/transforms/test_affine.py | 7 ++- tests/transforms/test_compose.py | 37 ++++++++++- tests/transforms/test_utility_transforms.py | 51 +++++++++++++++ 9 files changed, 293 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e65f7887..613ae3ef 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ | Date | Commit | Short Description | Breaking Changes? | | ---- | --------- | ----------------- | ----------------- | +| 2020-02-17 | f69f36918a03bac835ff6bbe9b5cc7045bcf5d92 | Add Transforms to pop and filter keys | No | +| 2020-02-17 | 5053ca76f7fc43ff813748d040c69d490ba0210c | Use ModuleList in Compose | No | +| 2020-02-17 | 26d7c9432b90247f8b9e0ab3e6881e88fb7749d4 | Add Resizing Transform | No | | 2020-01-03 | 6b9a7b2fdc7d0b894c0dfcfd94237845fe8b8672 | Affine Trafos | No| | 2019-12-24 | 6b90197e89dedd7659073bf72037390231a1c278 | Use shared memory for progressive resizing | No | | 2019-12-17 | 0b881f8e0ce85f380ecf458080c2a3f5cb8c3080 | User-Controllable call dispatch within the compose class | No | diff --git a/rising/transforms/affine.py b/rising/transforms/affine.py index 239d3873..d2840c8b 100644 --- a/rising/transforms/affine.py +++ b/rising/transforms/affine.py @@ -3,6 +3,7 @@ from rising.utils import check_scalar from rising.utils.affine import AffineParamType, \ assemble_matrix_if_necessary, matrix_to_homogeneous, matrix_to_cartesian +from rising.utils.checktype import check_scalar import torch from typing import Sequence, Union, Iterable @@ -12,7 +13,7 @@ 'Rotate', 'Scale', 'Translate', - 'Resize' + 'Resize', ] diff --git a/rising/transforms/compose.py b/rising/transforms/compose.py index 7cc520bb..2f601669 100644 --- a/rising/transforms/compose.py +++ b/rising/transforms/compose.py @@ -1,6 +1,7 @@ from typing import Sequence, Union, Callable, Any, Mapping from rising.utils import check_scalar from rising.transforms import AbstractTransform, RandomProcess +import torch __all__ = ["Compose", "DropoutCompose"] @@ -25,6 +26,31 @@ def dict_call(batch: dict, transform: Callable) -> Any: return transform(**batch) +class _TransformWrapper(torch.nn.Module): + def __init__(self, trafo: Callable): + """ + Helper Class to wrap all non-module transforms into modules to use the + torch.nn.ModuleList as container for the transforms. This enables + forwarding of all model specific calls as ``.to()`` to all transforms + + Parameters + ---------- + trafo : Callable + the actual transform, which will be wrapped by this class. 
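A small sketch of what the switch to torch.nn.ModuleList buys: module-specific calls such as .to() now reach buffers registered inside the composed transforms, and plain callables are wrapped transparently. The class and function names below are invented for the example, and the AbstractTransform constructor signature is assumed from the tests further down in this patch.

import torch
from rising.transforms import Compose
from rising.transforms.abstract import AbstractTransform

class BufferedTrafo(AbstractTransform):
    """Dummy transform that keeps its state in a registered buffer."""
    def __init__(self, value):
        super().__init__(grad=False)
        self.register_buffer("tmp", value)

    def forward(self, **data):
        data["data"] = data["data"] * self.tmp
        return data

def plain_callable(**data):
    return data   # not an nn.Module, gets wrapped into _TransformWrapper by Compose

compose = Compose([BufferedTrafo(torch.tensor([2.])), plain_callable])
compose = compose.to(torch.float64)        # propagates through the ModuleList
print(compose.transforms[0].tmp.dtype)     # torch.float64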
+ Since this transform is no subclass of ``torch.nn.Module``, + its internal state won't be affected by module specific calls + """ + super().__init__() + + self.trafo = trafo + + def forward(self, *args, **kwargs) -> Any: + """ + Forwards calls to this wrapper to the internal transform + """ + return self.trafo(*args, **kwargs) + + class Compose(AbstractTransform): def __init__(self, *transforms, transform_call: Callable[[Any, Callable], Any] = dict_call): @@ -42,7 +68,12 @@ def __init__(self, *transforms, super().__init__(grad=True) if isinstance(transforms[0], Sequence): transforms = transforms[0] - self.transforms = transforms + + for idx, trafo in enumerate(transforms): + if not isinstance(trafo, torch.nn.Module): + transforms[idx] = _TransformWrapper(trafo) + + self.transforms = torch.nn.ModuleList(transforms) self.transform_call = transform_call def forward(self, *seq_like, **map_like) -> Union[Sequence, Mapping]: diff --git a/rising/transforms/functional/utility.py b/rising/transforms/functional/utility.py index 9bd82a72..468e56a5 100644 --- a/rising/transforms/functional/utility.py +++ b/rising/transforms/functional/utility.py @@ -1,8 +1,8 @@ -from typing import Sequence, List, Tuple +from typing import Sequence, List, Tuple, Union, Callable from torch import Tensor import torch -__all__ = ["box_to_seg", "seg_to_box", "instance_to_semantic"] +__all__ = ["box_to_seg", "seg_to_box", "instance_to_semantic", "pop_keys", "filter_keys"] def box_to_seg(boxes: Sequence[Sequence[int]], shape: Sequence[int] = None, @@ -100,3 +100,68 @@ def instance_to_semantic(instance: Tensor, cls: Sequence[int]) -> Tensor: for idx, c in enumerate(cls, 1): seg[instance == idx] = c return seg + + +def pop_keys(data: dict, keys: Union[Callable, Sequence], return_popped=False) -> Union[dict, Tuple[dict, dict]]: + """ + Pops keys from a given data dict + + Parameters + ---------- + data : dict + the dictionary to pop the keys from + keys : Callable or Sequence of Strings + if callable it must return a boolean for each key indicating whether it should be popped from the dict. + if sequence of strings, the strings shall be the keys to be popped + return_popped : bool + whether to also return the popped values (default: False) + + Returns + ------- + dict + the data without the popped values + dict, optional + the popped values; only if ``return_popped`` is True + + """ + if callable(keys): + keys = [k for k in data.keys() if keys(k)] + + popped = {} + + for k in keys: + popped[k] = data.pop(k) + + if return_popped: + return data, popped + else: + return data + + +def filter_keys(data: dict, keys: Union[Callable, Sequence], return_popped=False) -> Union[dict, Tuple[dict, dict]]: + """ + Filters keys from a given data dict + + Parameters + ---------- + data : dict + the dictionary to pop the keys from + keys : Callable or Sequence of Strings + if callable it must return a boolean for each key indicating whether it should be retained in the dict. 
+ if sequence of strings, the strings shall be the keys to be retained + return_popped : bool + whether to also return the popped values (default: False) + + Returns + ------- + dict + the data without the popped values + dict, optional + the popped values; only if ``return_popped`` is True + + """ + if callable(keys): + keys = [k for k in data.keys() if keys(k)] + + keys_to_pop = [k for k in data.keys() if k not in keys] + return pop_keys(data=data, keys=keys_to_pop, return_popped=return_popped) diff --git a/rising/transforms/utility.py b/rising/transforms/utility.py index 8ab2a869..18fc1585 100644 --- a/rising/transforms/utility.py +++ b/rising/transforms/utility.py @@ -1,10 +1,10 @@ -from typing import Sequence, Mapping, Hashable, Union +from typing import Sequence, Mapping, Hashable, Union, Callable, Tuple import torch from rising.transforms.abstract import AbstractTransform -from rising.transforms.functional.utility import seg_to_box, box_to_seg, instance_to_semantic +from rising.transforms.functional.utility import seg_to_box, box_to_seg, instance_to_semantic, pop_keys, filter_keys -__all__ = ["DoNothing", "SegToBox", "BoxToSeg", "InstanceToSemantic"] +__all__ = ["DoNothing", "SegToBox", "BoxToSeg", "InstanceToSemantic", "PopKeys", "FilterKeys"] class DoNothing(AbstractTransform): @@ -122,3 +122,47 @@ def forward(self, **data) -> dict: data[target] = torch.cat([instance_to_semantic(data, mapping) for data, mapping in zip(data[source].split(1), data[self.cls_key])]) return data + + +class PopKeys(AbstractTransform): + def __init__(self, keys: Union[Callable, Sequence], return_popped: bool = False): + """ + Pops keys from a given data dict + + Parameters + ---------- + keys : Callable or Sequence of Strings + if callable it must return a boolean for each key indicating whether it should be popped from the dict. + if sequence of strings, the strings shall be the keys to be popped + return_popped : bool + whether to also return the popped values (default: False) + + """ + super().__init__(grad=False) + self.keys = keys + self.return_popped = return_popped + + def forward(self, **data) -> Union[dict, Tuple[dict, dict]]: + return pop_keys(data=data, keys=self.keys, return_popped=self.return_popped) + + +class FilterKeys(AbstractTransform): + def __init__(self, keys: Union[Callable, Sequence], return_popped: bool = False): + """ + Filters keys from a given data dict + + Parameters + ---------- + keys : Callable or Sequence of Strings + if callable it must return a boolean for each key indicating whether it should be retained in the dict. 
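For reference, a short usage sketch of the two new key transforms (relying on the exports listed in __all__ above); the dictionary content is arbitrary toy data:

from rising.transforms import FilterKeys, PopKeys

batch = {"data": 0, "label": 1, "meta": 2}

keep = FilterKeys(keys=("data", "label"))            # retain only these entries
print(keep(**batch))                                 # {'data': 0, 'label': 1}

drop = PopKeys(keys=("meta",), return_popped=True)   # remove by name, keep what was popped
remaining, popped = drop(**batch)
print(remaining, popped)                             # {'data': 0, 'label': 1} {'meta': 2}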
+ if sequence of strings, the strings shall be the keys to be retained + return_popped : bool + whether to also return the popped values (default: False) + + """ + super().__init__(grad=False) + self.keys = keys + self.return_popped = return_popped + + def forward(self, **data) -> Union[dict, Tuple[dict, dict]]: + return filter_keys(data=data, keys=self.keys, return_popped=self.return_popped) diff --git a/tests/transforms/functional/test_utility.py b/tests/transforms/functional/test_utility.py index 550399b4..c83f44d0 100644 --- a/tests/transforms/functional/test_utility.py +++ b/tests/transforms/functional/test_utility.py @@ -1,5 +1,6 @@ import unittest import torch +from copy import deepcopy from rising.transforms.functional import * @@ -40,6 +41,56 @@ def test_instance_to_semantic(self): expected[5:8, 5:8, 1:3] = 1 self.assertTrue((semantic == expected).all()) + def test_pop_keys(self): + data = {str(idx): idx for idx in range(10)} + keys_to_pop_list = [str(idx) for idx in range(0, 10, 2)] + + def keys_to_pop_fn(key): + return key in [str(idx) for idx in range(0, 10, 2)] + + for return_pop in [True, False]: + for _pop_keys in [keys_to_pop_list, keys_to_pop_fn]: + with self.subTest(return_pop=return_pop, pop_keys=_pop_keys): + if isinstance(_pop_keys, list): + __pop_keys = deepcopy(_pop_keys) + else: + __pop_keys = _pop_keys + result = pop_keys(data=deepcopy(data), keys=__pop_keys, + return_popped=return_pop) + + if return_pop: + result, popped = result + for k in popped.keys(): + self.assertIn(k, keys_to_pop_list) + + for k in result.keys(): + self.assertNotIn(k, keys_to_pop_list) + + def test_filter_keys(self): + data = {str(idx): idx for idx in range(10)} + keys_to_filter_list = [str(idx) for idx in range(0, 10, 2)] + + def keys_to_filter_fn(key): + return key in [str(idx) for idx in range(0, 10, 2)] + + for return_pop in [True, False]: + for _filter_keys in [keys_to_filter_list, keys_to_filter_fn]: + with self.subTest(return_pop=return_pop, filter_keys=_filter_keys): + if isinstance(_filter_keys, list): + __filter_keys = deepcopy(_filter_keys) + else: + __filter_keys = _filter_keys + result = filter_keys(data=deepcopy(data), keys=__filter_keys, + return_popped=return_pop) + + if return_pop: + result, popped = result + for k in popped.keys(): + self.assertNotIn(k, keys_to_filter_list) + + for k in result.keys(): + self.assertIn(k, keys_to_filter_list) + if __name__ == '__main__': unittest.main() diff --git a/tests/transforms/test_affine.py b/tests/transforms/test_affine.py index 2b6c7d04..91d29eda 100644 --- a/tests/transforms/test_affine.py +++ b/tests/transforms/test_affine.py @@ -1,7 +1,6 @@ import unittest -from rising.transforms.affine import ( - Affine, StackedAffine, Translate, Scale, Rotate) -# TODO: add resize transform +from rising.transforms.affine import Affine, StackedAffine, Translate, Rotate, \ + Scale, Resize import torch from copy import deepcopy from rising.utils.affine import matrix_to_cartesian, matrix_to_homogeneous @@ -99,6 +98,8 @@ def test_affine_subtypes(self): self.assertIsInstance(result, torch.Tensor) self.assertTupleEqual(expected_size, result.shape[-2:]) + self.assertTupleEqual((5, 4), trafos[-1](**sample)['data'].shape[2:]) + if __name__ == '__main__': unittest.main() diff --git a/tests/transforms/test_compose.py b/tests/transforms/test_compose.py index 451d6c96..e88af523 100644 --- a/tests/transforms/test_compose.py +++ b/tests/transforms/test_compose.py @@ -2,7 +2,8 @@ import torch from rising.transforms.spatial import Mirror -from 
rising.transforms.compose import Compose, DropoutCompose +from rising.transforms.compose import Compose, DropoutCompose, \ + AbstractTransform, _TransformWrapper class TestCompose(unittest.TestCase): @@ -47,6 +48,40 @@ def test_dropout_compose_error(self): with self.assertRaises(TypeError): compose = DropoutCompose(self.transforms, dropout=[1.0]) + def test_device_dtype_change(self): + class DummyTrafo(AbstractTransform): + def __init__(self, a): + super().__init__(False) + self.register_buffer('tmp', a) + + def __call__(self, *args, **kwargs): + return self.tmp + + trafo_a = DummyTrafo(torch.tensor([1.], dtype=torch.float32)) + trafo_a = trafo_a.to(torch.float32) + trafo_b = DummyTrafo(torch.tensor([2.], dtype=torch.float32)) + trafo_b = trafo_b.to(torch.float32) + self.assertEquals(trafo_a.tmp.dtype, torch.float32) + self.assertEquals(trafo_b.tmp.dtype, torch.float32) + compose = Compose(trafo_a, trafo_b) + compose = compose.to(torch.float64) + + self.assertEquals(compose.transforms[0].tmp.dtype, torch.float64) + + def test_wrapping_non_module_trafos(self): + class DummyTrafo: + def __init__(self): + self.a = 5 + + def __call__(self, *args, **kwargs): + return 5 + + dummy_trafo = DummyTrafo() + + compose = Compose([dummy_trafo]) + self.assertIsInstance(compose.transforms[0], _TransformWrapper) + self.assertIsInstance(compose.transforms[0].trafo, DummyTrafo) + if __name__ == '__main__': unittest.main() diff --git a/tests/transforms/test_utility_transforms.py b/tests/transforms/test_utility_transforms.py index 1bafcfaa..dc80ad85 100644 --- a/tests/transforms/test_utility_transforms.py +++ b/tests/transforms/test_utility_transforms.py @@ -1,5 +1,6 @@ import unittest import torch +from copy import deepcopy from rising.transforms.utility import * @@ -43,6 +44,56 @@ def test_instance_to_semantic_transform(self): expected[5:8, 5:8, 1:3] = 1 self.assertTrue((semantic == expected).all()) + def test_pop_keys(self): + data = {str(idx): idx for idx in range(10)} + keys_to_pop_list = [str(idx) for idx in range(0, 10, 2)] + + def keys_to_pop_fn(key): + return key in [str(idx) for idx in range(0, 10, 2)] + + for return_pop in [True, False]: + for _pop_keys in [keys_to_pop_list, keys_to_pop_fn]: + with self.subTest(return_pop=return_pop, pop_keys=_pop_keys): + if isinstance(_pop_keys, list): + __pop_keys = deepcopy(_pop_keys) + else: + __pop_keys = _pop_keys + result = PopKeys(keys=__pop_keys, + return_popped=return_pop)(**deepcopy(data)) + + if return_pop: + result, popped = result + for k in popped.keys(): + self.assertIn(k, keys_to_pop_list) + + for k in result.keys(): + self.assertNotIn(k, keys_to_pop_list) + + def test_filter_keys(self): + data = {str(idx): idx for idx in range(10)} + keys_to_filter_list = [str(idx) for idx in range(0, 10, 2)] + + def keys_to_filter_fn(key): + return key in [str(idx) for idx in range(0, 10, 2)] + + for return_pop in [True, False]: + for _filter_keys in [keys_to_filter_list, keys_to_filter_fn]: + with self.subTest(return_pop=return_pop, filter_keys=_filter_keys): + if isinstance(_filter_keys, list): + __filter_keys = deepcopy(_filter_keys) + else: + __filter_keys = _filter_keys + result = FilterKeys(keys=__filter_keys, + return_popped=return_pop)(**deepcopy(data)) + + if return_pop: + result, popped = result + for k in popped.keys(): + self.assertNotIn(k, keys_to_filter_list) + + for k in result.keys(): + self.assertIn(k, keys_to_filter_list) + if __name__ == '__main__': unittest.main() From 5c9f6ebdfc3e3225dac4901692e72c09ffac00e1 Mon Sep 17 00:00:00 2001 From: 
mibaumgartner Date: Tue, 18 Feb 2020 18:29:21 +0100 Subject: [PATCH 04/20] rename and move functions --- rising/transforms/affine.py | 7 +- rising/transforms/functional/affine.py | 452 ++++++++++++++++++++- rising/utils/affine.py | 446 -------------------- tests/transforms/functional/test_affine.py | 198 ++++++++- tests/transforms/test_affine.py | 4 - tests/transforms/test_compose.py | 7 +- tests/utils/test_affine.py | 195 +-------- 7 files changed, 653 insertions(+), 656 deletions(-) diff --git a/rising/transforms/affine.py b/rising/transforms/affine.py index d2840c8b..666526eb 100644 --- a/rising/transforms/affine.py +++ b/rising/transforms/affine.py @@ -1,8 +1,7 @@ from rising.transforms.abstract import BaseTransform -from rising.transforms.functional.affine import affine_image_transform -from rising.utils import check_scalar -from rising.utils.affine import AffineParamType, \ - assemble_matrix_if_necessary, matrix_to_homogeneous, matrix_to_cartesian +from rising.transforms.functional.affine import affine_image_transform, \ + AffineParamType, assemble_matrix_if_necessary +from rising.utils.affine import matrix_to_homogeneous, matrix_to_cartesian from rising.utils.checktype import check_scalar import torch from typing import Sequence, Union, Iterable diff --git a/rising/transforms/functional/affine.py b/rising/transforms/functional/affine.py index 062626e3..cc2fd7ca 100644 --- a/rising/transforms/functional/affine.py +++ b/rising/transforms/functional/affine.py @@ -1,8 +1,11 @@ import torch +import warnings +from typing import Union, Sequence + from rising.utils.affine import points_to_cartesian, matrix_to_homogeneous, \ - points_to_homogeneous, unit_box + points_to_homogeneous, unit_box, get_batched_eye, deg_to_rad, matrix_to_cartesian from rising.utils.checktype import check_scalar -import warnings + __all__ = [ 'affine_image_transform', @@ -10,6 +13,451 @@ ] +AffineParamType = Union[int, float, Sequence, torch.Tensor] + + +def create_scale(scale: AffineParamType, + batchsize: int, ndim: int, + device: Union[torch.device, str] = None, + dtype: Union[torch.dtype, str] = None) -> torch.Tensor: + """ + Formats the given scale parameters to a homogeneous transformation matrix + + Parameters + ---------- + scale : torch.Tensor, int, float + the scale factor(s). Supported are: + * a full transformation matrix of shape (BATCHSIZE x NDIM x NDIM) + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a single parameter per sample (as a 1d tensor), which will be + replicated for all dimensions + * a single parameter per dimension (either as 1d tensor or as + 2d transformation matrix), which will be replicated for all + batch samples + None will be treated as a scaling factor of 1 + batchsize : int + the number of samples per batch + ndim : int + the dimensionality of the transform + device : torch.device, str, optional + the device to put the resulting tensor to. Defaults to the default + device + dtype : torch.dtype, str, optional + the dtype of the resulting trensor. 
Defaults to the default dtype + + Returns + ------- + torch.Tensor + the homogeneous transformation matrix + + """ + + if scale is None: + scale = 1 + + if check_scalar(scale): + + scale = get_batched_eye(batchsize=batchsize, ndim=ndim, device=device, + dtype=dtype) * scale + + elif not torch.is_tensor(scale): + scale = torch.tensor(scale, dtype=dtype, device=device) + + # scale must be tensor by now + scale = scale.to(device=device, dtype=dtype) + + # scale is already batched matrix + if scale.size() == (batchsize, ndim, ndim) or scale.size() == (batchsize, ndim, ndim + 1): + return matrix_to_homogeneous(scale) + + # scale is batched matrix with same element for each dimension or just + # not diagonalized + if scale.size() == (batchsize, ndim) or scale.size() == (batchsize,): + new_scale = get_batched_eye(batchsize=batchsize, ndim=ndim, + device=device, dtype=dtype) + + return matrix_to_homogeneous(new_scale * scale.view(batchsize, -1, 1)) + + # scale contains a non-diagonalized form (will be repeated for each batch + # item) + elif scale.size() == (ndim,): + return matrix_to_homogeneous( + torch.diag(scale).view(1, ndim, ndim).expand(batchsize, + -1, -1).clone()) + + # scale contains a diagonalized but not batched matrix + # (will be repeated for each batch item) + elif scale.size() == (ndim, ndim): + return matrix_to_homogeneous( + scale.view(1, ndim, ndim).expand(batchsize, -1, -1).clone()) + + raise ValueError("Unknown shape for scale matrix: %s" + % str(tuple(scale.size()))) + + +def create_translation(offset: AffineParamType, + batchsize: int, ndim: int, + device: Union[torch.device, str] = None, + dtype: Union[torch.dtype, str] = None + ) -> torch.Tensor: + """ + Formats the given translation parameters to a homogeneous transformation + matrix + + Parameters + ---------- + offset : torch.Tensor, int, float + the translation offset(s). Supported are: + * a full homogeneous transformation matrix of shape + (BATCHSIZE x NDIM+1 x NDIM+1) + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a single parameter per sample (as a 1d tensor), which will be + replicated for all dimensions + * a single parameter per dimension (either as 1d tensor or as + 2d transformation matrix), which will be replicated for all + batch samples + None will be treated as a translation offset of 0 + batchsize : int + the number of samples per batch + ndim : int + the dimensionality of the transform + device : torch.device, str, optional + the device to put the resulting tensor to. Defaults to the default + device + dtype : torch.dtype, str, optional + the dtype of the resulting trensor. 
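A quick sanity check of the parameter broadcasting described for create_scale above; the values are chosen only for illustration:

import torch
from rising.transforms.functional.affine import create_scale

# a scalar is replicated over all dimensions and batch samples
print(create_scale(2, batchsize=1, ndim=2))
# tensor([[[2., 0., 0.],
#          [0., 2., 0.],
#          [0., 0., 1.]]])   -> homogeneous 3 x 3 matrix for a 2D scaling

# one factor per dimension is accepted as well
print(create_scale(torch.tensor([2., 3.]), batchsize=1, ndim=2))
# same layout, but the diagonal becomes (2, 3, 1)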
Defaults to the default dtype + + Returns + ------- + torch.Tensor + the homogeneous transformation matrix + + """ + if offset is None: + offset = 0 + + if check_scalar(offset): + offset = torch.tensor([offset] * ndim, device=device, dtype=dtype) + + elif not torch.is_tensor(offset): + offset = torch.tensor(offset, device=device, dtype=dtype) + + # assumes offset to be tensor from now on + offset = offset.to(device=device, dtype=dtype) + + # translation matrix already built + if offset.size() == (batchsize, ndim + 1, ndim + 1): + return offset + elif offset.size() == (batchsize, ndim, ndim + 1): + return matrix_to_homogeneous(offset) + + # not completely built so far -> bring in shape (batchsize, ndim) + if offset.size() == (batchsize,): + offset = offset.view(-1, 1).expand(-1, ndim).clone() + elif offset.size() == (ndim,): + offset = offset.view(1, -1).expand(batchsize, -1).clone() + elif not offset.size() == (batchsize, ndim): + raise ValueError("Unknown shape for offsets: %s" + % str(tuple(offset.shape))) + + # directly build homogeneous form -> use dim+1 + whole_translation_matrix = get_batched_eye(batchsize=batchsize, + ndim=ndim + 1, device=device, + dtype=dtype) + + whole_translation_matrix[:, :-1, -1] = offset.clone() + return whole_translation_matrix + + +def create_rotation(rotation: AffineParamType, + batchsize: int, ndim: int, + degree: bool = False, + device: Union[torch.device, str] = None, + dtype: Union[torch.dtype, str] = None) -> torch.Tensor: + """ + Formats the given scale parameters to a homogeneous transformation matrix + + Parameters + ---------- + rotation : torch.Tensor, int, float + the rotation factor(s). Supported are: + * a full transformation matrix of shape (BATCHSIZE x NDIM(+1) x NDIM(+1)) + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a single parameter per sample (as a 1d tensor), which will be + replicated for all dimensions + * a single parameter per dimension (either as 1d tensor or as + 2d transformation matrix), which will be replicated for all + batch samples + None will be treated as a rotation factor of 0 + batchsize : int + the number of samples per batch + ndim : int + the dimensionality of the transform + degree : bool + whether the given rotation(s) are in degrees. + Only valid for rotation parameters, which aren't passed as full + transformation matrix. + device : torch.device, str, optional + the device to put the resulting tensor to. Defaults to the default + device + dtype : torch.dtype, str, optional + the dtype of the resulting trensor. 
Defaults to the default dtype + + Returns + ------- + torch.Tensor + the homogeneous transformation matrix + + """ + if rotation is None: + rotation = 0 + + num_rot_params = 1 if ndim == 2 else ndim + + if check_scalar(rotation): + rotation = torch.ones(batchsize, num_rot_params, + device=device, dtype=dtype) * rotation + elif not torch.is_tensor(rotation): + rotation = torch.tensor(rotation, device=device, dtype=dtype) + + # assumes rotation to be tensor by now + rotation = rotation.to(device=device, dtype=dtype) + + # already complete + if rotation.size() == (batchsize, ndim, ndim) or rotation.size() == (batchsize, ndim, ndim + 1): + return matrix_to_homogeneous(rotation) + elif rotation.size() == (batchsize, ndim + 1, ndim + 1): + return rotation + + if degree: + rotation = deg_to_rad(rotation) + + # repeat along batch dimension + if rotation.size() == (ndim, ndim) or rotation.size() == (ndim + 1, ndim + 1): + rotation = rotation[None].expand(batchsize, -1, -1).clone() + if rotation.size(-1) == ndim: + rotation = matrix_to_homogeneous(rotation) + + return rotation + # bring it to default size of (batchsize, num_rot_params) + elif rotation.size() == (batchsize,): + rotation = rotation.view(batchsize, 1).expand(-1, num_rot_params).clone() + elif rotation.size() == (num_rot_params,): + rotation = rotation.view(1, num_rot_params).expand(batchsize, + -1).clone() + elif rotation.size() != (batchsize, num_rot_params): + raise ValueError("Invalid shape for rotation parameters: %s" + % (str(tuple(rotation.size())))) + + sin, cos = rotation.sin(), rotation.cos() + + whole_rot_matrix = get_batched_eye(batchsize=batchsize, ndim=ndim, + device=device, dtype=dtype) + + # assemble the actual matrix + if num_rot_params == 1: + whole_rot_matrix[:, 0, 0] = cos[0].clone() + whole_rot_matrix[:, 1, 1] = cos[0].clone() + whole_rot_matrix[:, 0, 1] = (-sin[0]).clone() + whole_rot_matrix[:, 1, 0] = sin[0].clone() + else: + whole_rot_matrix[:, 0, 0] = (cos[:, 0] * cos[:, 1] * cos[:, 2] + - sin[:, 0] * sin[:, 2]).clone() + whole_rot_matrix[:, 0, 1] = (-cos[:, 0] * cos[:, 1] * sin[:, 2] + - sin[:, 0] * cos[:, 2]).clone() + whole_rot_matrix[:, 0, 2] = (cos[:, 0] * sin[:, 1]).clone() + whole_rot_matrix[:, 1, 0] = (sin[:, 0] * cos[:, 1] * cos[:, 2] + + cos[:, 0] * sin[:, 2]).clone() + whole_rot_matrix[:, 1, 1] = (-sin[:, 0] * cos[:, 1] * sin[:, 2] + + cos[:, 0] * cos[:, 2]).clone() + whole_rot_matrix[:, 2, 0] = (-sin[:, 1] * cos[:, 2]).clone() + whole_rot_matrix[:, 2, 1] = (-sin[:, 1] * sin[:, 2]).clone() + whole_rot_matrix[:, 2, 2] = (cos[:, 1]).clone() + + return matrix_to_homogeneous(whole_rot_matrix) + + +def parametrize_matrix(scale: AffineParamType, + rotation: AffineParamType, + translation: AffineParamType, + batchsize: int, ndim: int, + degree: bool = False, + device: Union[torch.device, str] = None, + dtype: Union[torch.dtype, str] = None) -> torch.Tensor: + """ + Formats the given scale parameters to a homogeneous transformation matrix + + Parameters + ---------- + scale : torch.Tensor, int, float + the scale factor(s). 
Supported are: + * a full transformation matrix of shape (BATCHSIZE x NDIM x NDIM) + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a single parameter per sample (as a 1d tensor), which will be + replicated for all dimensions + * a single parameter per dimension (either as 1d tensor or as + 2d transformation matrix), which will be replicated for all + batch samples + None will be treated as a scaling factor of 1 + rotation : torch.Tensor, int, float + the rotation factor(s). Supported are: + * a full transformation matrix of shape (BATCHSIZE x NDIM x NDIM) + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a single parameter per sample (as a 1d tensor), which will be + replicated for all dimensions + * a single parameter per dimension (either as 1d tensor or as + 2d transformation matrix), which will be replicated for all + batch samples + None will be treated as a rotation factor of 1 + translation : torch.Tensor, int, float + the translation offset(s). Supported are: + * a full homogeneous transformation matrix of shape + (BATCHSIZE x NDIM+1 x NDIM+1) + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a single parameter per sample (as a 1d tensor), which will be + replicated for all dimensions + * a single parameter per dimension (either as 1d tensor or as + 2d transformation matrix), which will be replicated for all + batch samples + None will be treated as a translation offset of 0 + batchsize : int + the number of samples per batch + ndim : int + the dimensionality of the transform + degree : bool + whether the given rotation(s) are in degrees. + Only valid for rotation parameters, which aren't passed as full + transformation matrix. + device : torch.device, str, optional + the device to put the resulting tensor to. Defaults to the default + device + dtype : torch.dtype, str, optional + the dtype of the resulting trensor. Defaults to the default dtype + + Returns + ------- + torch.Tensor + the transformation matrix (of shape (BATCHSIZE x NDIM x NDIM+1) + + """ + scale = create_scale(scale, batchsize=batchsize, ndim=ndim, + device=device, dtype=dtype) + rotation = create_rotation(rotation, batchsize=batchsize, ndim=ndim, + degree=degree, device=device, dtype=dtype) + + translation = create_translation(translation, batchsize=batchsize, + ndim=ndim, device=device, dtype=dtype) + + return torch.bmm(torch.bmm(scale, rotation), translation)[:, :-1] + + +def assemble_matrix_if_necessary(batchsize: int, ndim: int, + scale: AffineParamType, + rotation: AffineParamType, + translation: AffineParamType, + matrix: torch.Tensor, + degree: bool, + device: Union[torch.device, str], + dtype: Union[torch.dtype, str] + ) -> torch.Tensor: + """ + Assembles a matrix, if the matrix is not already given + + Parameters + ---------- + batchsize : int + number of samples per batch + ndim : int + the image dimensionality + scale : torch.Tensor, int, float + the scale factor(s). 
Supported are: + * a full transformation matrix of shape (BATCHSIZE x NDIM x NDIM) + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a single parameter per sample (as a 1d tensor), which will be + replicated for all dimensions + * a single parameter per dimension (either as 1d tensor or as + 2d transformation matrix), which will be replicated for all + batch samples + None will be treated as a scaling factor of 1 + rotation : torch.Tensor, int, float + the rotation factor(s). Supported are: + * a full transformation matrix of shape (BATCHSIZE x NDIM x NDIM) + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a single parameter per sample (as a 1d tensor), which will be + replicated for all dimensions + * a single parameter per dimension (either as 1d tensor or as + 2d transformation matrix), which will be replicated for all + batch samples + None will be treated as a rotation factor of 1 + translation : torch.Tensor, int, float + the translation offset(s). Supported are: + * a full homogeneous transformation matrix of shape + (BATCHSIZE x NDIM+1 x NDIM+1) + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a single parameter per sample (as a 1d tensor), which will be + replicated for all dimensions + * a single parameter per dimension (either as 1d tensor or as + 2d transformation matrix), which will be replicated for all + batch samples + None will be treated as a translation offset of 0 + matrix : torch.Tensor + the transformation matrix. If other than None: overwrites separate + parameters for :param:`scale`, :param:`rotation` and + :param:`translation` + degree : bool + whether the given rotation is in degrees. Only valid for explicit + rotation parameters + device : str, torch.device + the device, the matrix should be put on + dtype : str, torch.dtype + the datatype, the matrix should have + + Returns + ------- + torch.Tensor + the assembled transformation matrix + + """ + if matrix is None: + matrix = parametrize_matrix(scale=scale, rotation=rotation, + translation=translation, + batchsize=batchsize, + ndim=ndim, + degree=degree, + device=device, + dtype=dtype) + + else: + if not torch.is_tensor(matrix): + matrix = torch.tensor(matrix) + + matrix = matrix.to(dtype=dtype, device=device) + + # batch dimension missing -> Replicate for each sample in batch + if len(matrix.shape) == 2: + matrix = matrix[None].expand(batchsize, -1, -1).clone() + + if matrix.shape == (batchsize, ndim, ndim + 1): + return matrix + elif matrix.shape == (batchsize, ndim + 1, ndim + 1): + return matrix_to_cartesian(matrix) + + raise ValueError( + "Invalid Shape for affine transformation matrix. 
" + "Got %s but expected %s" % ( + str(tuple(matrix.shape)), + str((batchsize, ndim, ndim + 1)))) + + def affine_point_transform(point_batch: torch.Tensor, matrix_batch: torch.Tensor) -> torch.Tensor: """ diff --git a/rising/utils/affine.py b/rising/utils/affine.py index 9b6da49a..ab61c928 100644 --- a/rising/utils/affine.py +++ b/rising/utils/affine.py @@ -4,10 +4,6 @@ from math import pi from typing import Union, Sequence -from rising.utils.checktype import check_scalar - -AffineParamType = Union[int, float, Sequence, torch.Tensor] - def points_to_homogeneous(batch: torch.Tensor) -> torch.Tensor: """ @@ -137,161 +133,6 @@ def get_batched_eye(batchsize: int, ndim: int, 1, ndim, ndim).expand(batchsize, -1, -1).clone() -def _format_scale(scale: AffineParamType, - batchsize: int, ndim: int, - device: Union[torch.device, str] = None, - dtype: Union[torch.dtype, str] = None) -> torch.Tensor: - """ - Formats the given scale parameters to a homogeneous transformation matrix - - Parameters - ---------- - scale : torch.Tensor, int, float - the scale factor(s). Supported are: - * a full transformation matrix of shape (BATCHSIZE x NDIM x NDIM) - * a single parameter (as float or int), which will be replicated - for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be - replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all - batch samples - None will be treated as a scaling factor of 1 - batchsize : int - the number of samples per batch - ndim : int - the dimensionality of the transform - device : torch.device, str, optional - the device to put the resulting tensor to. Defaults to the default - device - dtype : torch.dtype, str, optional - the dtype of the resulting trensor. 
Defaults to the default dtype - - Returns - ------- - torch.Tensor - the homogeneous transformation matrix - - """ - - if scale is None: - scale = 1 - - if check_scalar(scale): - - scale = get_batched_eye(batchsize=batchsize, ndim=ndim, device=device, - dtype=dtype) * scale - - elif not torch.is_tensor(scale): - scale = torch.tensor(scale, dtype=dtype, device=device) - - # scale must be tensor by now - scale = scale.to(device=device, dtype=dtype) - - # scale is already batched matrix - if scale.size() == (batchsize, ndim, ndim) or scale.size() == (batchsize, ndim, ndim + 1): - return matrix_to_homogeneous(scale) - - # scale is batched matrix with same element for each dimension or just - # not diagonalized - if scale.size() == (batchsize, ndim) or scale.size() == (batchsize,): - new_scale = get_batched_eye(batchsize=batchsize, ndim=ndim, - device=device, dtype=dtype) - - return matrix_to_homogeneous(new_scale * scale.view(batchsize, -1, 1)) - - # scale contains a non-diagonalized form (will be repeated for each batch - # item) - elif scale.size() == (ndim,): - return matrix_to_homogeneous( - torch.diag(scale).view(1, ndim, ndim).expand(batchsize, - -1, -1).clone()) - - # scale contains a diagonalized but not batched matrix - # (will be repeated for each batch item) - elif scale.size() == (ndim, ndim): - return matrix_to_homogeneous( - scale.view(1, ndim, ndim).expand(batchsize, -1, -1).clone()) - - raise ValueError("Unknown shape for scale matrix: %s" - % str(tuple(scale.size()))) - - -def _format_translation(offset: AffineParamType, - batchsize: int, ndim: int, - device: Union[torch.device, str] = None, - dtype: Union[torch.dtype, str] = None - ) -> torch.Tensor: - """ - Formats the given translation parameters to a homogeneous transformation - matrix - - Parameters - ---------- - offset : torch.Tensor, int, float - the translation offset(s). Supported are: - * a full homogeneous transformation matrix of shape - (BATCHSIZE x NDIM+1 x NDIM+1) - * a single parameter (as float or int), which will be replicated - for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be - replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all - batch samples - None will be treated as a translation offset of 0 - batchsize : int - the number of samples per batch - ndim : int - the dimensionality of the transform - device : torch.device, str, optional - the device to put the resulting tensor to. Defaults to the default - device - dtype : torch.dtype, str, optional - the dtype of the resulting trensor. 
Defaults to the default dtype - - Returns - ------- - torch.Tensor - the homogeneous transformation matrix - - """ - if offset is None: - offset = 0 - - if check_scalar(offset): - offset = torch.tensor([offset] * ndim, device=device, dtype=dtype) - - elif not torch.is_tensor(offset): - offset = torch.tensor(offset, device=device, dtype=dtype) - - # assumes offset to be tensor from now on - offset = offset.to(device=device, dtype=dtype) - - # translation matrix already built - if offset.size() == (batchsize, ndim + 1, ndim + 1): - return offset - elif offset.size() == (batchsize, ndim, ndim + 1): - return matrix_to_homogeneous(offset) - - # not completely built so far -> bring in shape (batchsize, ndim) - if offset.size() == (batchsize,): - offset = offset.view(-1, 1).expand(-1, ndim).clone() - elif offset.size() == (ndim,): - offset = offset.view(1, -1).expand(batchsize, -1).clone() - elif not offset.size() == (batchsize, ndim): - raise ValueError("Unknown shape for offsets: %s" - % str(tuple(offset.shape))) - - # directly build homogeneous form -> use dim+1 - whole_translation_matrix = get_batched_eye(batchsize=batchsize, - ndim=ndim + 1, device=device, - dtype=dtype) - - whole_translation_matrix[:, :-1, -1] = offset.clone() - return whole_translation_matrix - - def deg_to_rad(angles: Union[torch.Tensor, float, int] ) -> Union[torch.Tensor, float, int]: """ @@ -311,293 +152,6 @@ def deg_to_rad(angles: Union[torch.Tensor, float, int] return angles * pi / 180 -def _format_rotation(rotation: AffineParamType, - batchsize: int, ndim: int, - degree: bool = False, - device: Union[torch.device, str] = None, - dtype: Union[torch.dtype, str] = None) -> torch.Tensor: - """ - Formats the given scale parameters to a homogeneous transformation matrix - - Parameters - ---------- - rotation : torch.Tensor, int, float - the rotation factor(s). Supported are: - * a full transformation matrix of shape (BATCHSIZE x NDIM(+1) x NDIM(+1)) - * a single parameter (as float or int), which will be replicated - for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be - replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all - batch samples - None will be treated as a rotation factor of 0 - batchsize : int - the number of samples per batch - ndim : int - the dimensionality of the transform - degree : bool - whether the given rotation(s) are in degrees. - Only valid for rotation parameters, which aren't passed as full - transformation matrix. - device : torch.device, str, optional - the device to put the resulting tensor to. Defaults to the default - device - dtype : torch.dtype, str, optional - the dtype of the resulting trensor. 
Defaults to the default dtype - - Returns - ------- - torch.Tensor - the homogeneous transformation matrix - - """ - if rotation is None: - rotation = 0 - - num_rot_params = 1 if ndim == 2 else ndim - - if check_scalar(rotation): - rotation = torch.ones(batchsize, num_rot_params, - device=device, dtype=dtype) * rotation - elif not torch.is_tensor(rotation): - rotation = torch.tensor(rotation, device=device, dtype=dtype) - - # assumes rotation to be tensor by now - rotation = rotation.to(device=device, dtype=dtype) - - # already complete - if rotation.size() == (batchsize, ndim, ndim) or rotation.size() == (batchsize, ndim, ndim + 1): - return matrix_to_homogeneous(rotation) - elif rotation.size() == (batchsize, ndim + 1, ndim + 1): - return rotation - - if degree: - rotation = deg_to_rad(rotation) - - # repeat along batch dimension - if rotation.size() == (ndim, ndim) or rotation.size() == (ndim + 1, ndim + 1): - rotation = rotation[None].expand(batchsize, -1, -1).clone() - if rotation.size(-1) == ndim: - rotation = matrix_to_homogeneous(rotation) - - return rotation - # bring it to default size of (batchsize, num_rot_params) - elif rotation.size() == (batchsize,): - rotation = rotation.view(batchsize, 1).expand(-1, num_rot_params).clone() - elif rotation.size() == (num_rot_params,): - rotation = rotation.view(1, num_rot_params).expand(batchsize, - -1).clone() - elif rotation.size() != (batchsize, num_rot_params): - raise ValueError("Invalid shape for rotation parameters: %s" - % (str(tuple(rotation.size())))) - - sin, cos = rotation.sin(), rotation.cos() - - whole_rot_matrix = get_batched_eye(batchsize=batchsize, ndim=ndim, - device=device, dtype=dtype) - - # assemble the actual matrix - if num_rot_params == 1: - whole_rot_matrix[:, 0, 0] = cos[0].clone() - whole_rot_matrix[:, 1, 1] = cos[0].clone() - whole_rot_matrix[:, 0, 1] = (-sin[0]).clone() - whole_rot_matrix[:, 1, 0] = sin[0].clone() - else: - whole_rot_matrix[:, 0, 0] = (cos[:, 0] * cos[:, 1] * cos[:, 2] - - sin[:, 0] * sin[:, 2]).clone() - whole_rot_matrix[:, 0, 1] = (-cos[:, 0] * cos[:, 1] * sin[:, 2] - - sin[:, 0] * cos[:, 2]).clone() - whole_rot_matrix[:, 0, 2] = (cos[:, 0] * sin[:, 1]).clone() - whole_rot_matrix[:, 1, 0] = (sin[:, 0] * cos[:, 1] * cos[:, 2] - + cos[:, 0] * sin[:, 2]).clone() - whole_rot_matrix[:, 1, 1] = (-sin[:, 0] * cos[:, 1] * sin[:, 2] - + cos[:, 0] * cos[:, 2]).clone() - whole_rot_matrix[:, 2, 0] = (-sin[:, 1] * cos[:, 2]).clone() - whole_rot_matrix[:, 2, 1] = (-sin[:, 1] * sin[:, 2]).clone() - whole_rot_matrix[:, 2, 2] = (cos[:, 1]).clone() - - return matrix_to_homogeneous(whole_rot_matrix) - - -def parametrize_matrix(scale: AffineParamType, - rotation: AffineParamType, - translation: AffineParamType, - batchsize: int, ndim: int, - degree: bool = False, - device: Union[torch.device, str] = None, - dtype: Union[torch.dtype, str] = None) -> torch.Tensor: - """ - Formats the given scale parameters to a homogeneous transformation matrix - - Parameters - ---------- - scale : torch.Tensor, int, float - the scale factor(s). 
Supported are: - * a full transformation matrix of shape (BATCHSIZE x NDIM x NDIM) - * a single parameter (as float or int), which will be replicated - for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be - replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all - batch samples - None will be treated as a scaling factor of 1 - rotation : torch.Tensor, int, float - the rotation factor(s). Supported are: - * a full transformation matrix of shape (BATCHSIZE x NDIM x NDIM) - * a single parameter (as float or int), which will be replicated - for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be - replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all - batch samples - None will be treated as a rotation factor of 1 - translation : torch.Tensor, int, float - the translation offset(s). Supported are: - * a full homogeneous transformation matrix of shape - (BATCHSIZE x NDIM+1 x NDIM+1) - * a single parameter (as float or int), which will be replicated - for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be - replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all - batch samples - None will be treated as a translation offset of 0 - batchsize : int - the number of samples per batch - ndim : int - the dimensionality of the transform - degree : bool - whether the given rotation(s) are in degrees. - Only valid for rotation parameters, which aren't passed as full - transformation matrix. - device : torch.device, str, optional - the device to put the resulting tensor to. Defaults to the default - device - dtype : torch.dtype, str, optional - the dtype of the resulting trensor. Defaults to the default dtype - - Returns - ------- - torch.Tensor - the transformation matrix (of shape (BATCHSIZE x NDIM x NDIM+1) - - """ - scale = _format_scale(scale, batchsize=batchsize, ndim=ndim, - device=device, dtype=dtype) - rotation = _format_rotation(rotation, batchsize=batchsize, ndim=ndim, - degree=degree, device=device, dtype=dtype) - - translation = _format_translation(translation, batchsize=batchsize, - ndim=ndim, device=device, dtype=dtype) - - return torch.bmm(torch.bmm(scale, rotation), translation)[:, :-1] - - -def assemble_matrix_if_necessary(batchsize: int, ndim: int, - scale: AffineParamType, - rotation: AffineParamType, - translation: AffineParamType, - matrix: torch.Tensor, - degree: bool, - device: Union[torch.device, str], - dtype: Union[torch.dtype, str] - ) -> torch.Tensor: - """ - Assembles a matrix, if the matrix is not already given - - Parameters - ---------- - batchsize : int - number of samples per batch - ndim : int - the image dimensionality - scale : torch.Tensor, int, float - the scale factor(s). 
Supported are: - * a full transformation matrix of shape (BATCHSIZE x NDIM x NDIM) - * a single parameter (as float or int), which will be replicated - for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be - replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all - batch samples - None will be treated as a scaling factor of 1 - rotation : torch.Tensor, int, float - the rotation factor(s). Supported are: - * a full transformation matrix of shape (BATCHSIZE x NDIM x NDIM) - * a single parameter (as float or int), which will be replicated - for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be - replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all - batch samples - None will be treated as a rotation factor of 1 - translation : torch.Tensor, int, float - the translation offset(s). Supported are: - * a full homogeneous transformation matrix of shape - (BATCHSIZE x NDIM+1 x NDIM+1) - * a single parameter (as float or int), which will be replicated - for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be - replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all - batch samples - None will be treated as a translation offset of 0 - matrix : torch.Tensor - the transformation matrix. If other than None: overwrites separate - parameters for :param:`scale`, :param:`rotation` and - :param:`translation` - degree : bool - whether the given rotation is in degrees. Only valid for explicit - rotation parameters - device : str, torch.device - the device, the matrix should be put on - dtype : str, torch.dtype - the datatype, the matrix should have - - Returns - ------- - torch.Tensor - the assembled transformation matrix - - """ - if matrix is None: - matrix = parametrize_matrix(scale=scale, rotation=rotation, - translation=translation, - batchsize=batchsize, - ndim=ndim, - degree=degree, - device=device, - dtype=dtype) - - else: - if not torch.is_tensor(matrix): - matrix = torch.tensor(matrix) - - matrix = matrix.to(dtype=dtype, device=device) - - # batch dimension missing -> Replicate for each sample in batch - if len(matrix.shape) == 2: - matrix = matrix[None].expand(batchsize, -1, -1).clone() - - if matrix.shape == (batchsize, ndim, ndim + 1): - return matrix - elif matrix.shape == (batchsize, ndim + 1, ndim + 1): - return matrix_to_cartesian(matrix) - - raise ValueError( - "Invalid Shape for affine transformation matrix. 
" - "Got %s but expected %s" % ( - str(tuple(matrix.shape)), - str((batchsize, ndim, ndim + 1)))) - - def unit_box(n: int, scale: torch.Tensor = None) -> torch.Tensor: """ Create a sclaed version of a unit box diff --git a/tests/transforms/functional/test_affine.py b/tests/transforms/functional/test_affine.py index 8cf519b4..285e7d6d 100644 --- a/tests/transforms/functional/test_affine.py +++ b/tests/transforms/functional/test_affine.py @@ -1,9 +1,9 @@ import unittest import torch from rising.transforms.functional.affine import _check_new_img_size, \ - affine_point_transform, affine_image_transform -from rising.utils.affine import parametrize_matrix, matrix_to_homogeneous, matrix_to_cartesian -from rising.utils.checktype import check_scalar + affine_point_transform, affine_image_transform, parametrize_matrix, \ + create_rotation, create_translation, create_scale, assemble_matrix_if_necessary +from rising.utils.affine import matrix_to_homogeneous, matrix_to_cartesian class AffineTestCase(unittest.TestCase): @@ -142,6 +142,198 @@ def test_affine_image_trafo(self): self.assertTupleEqual(result.shape[2:], target_size) + def test_create_scale(self): + inputs = [ + {'scale': None, 'batchsize': 2, 'ndim': 2}, + {'scale': 2, 'batchsize': 2, 'ndim': 2}, + {'scale': [2, 3], 'batchsize': 3, 'ndim': 2}, + {'scale': [2, 3, 4], 'batchsize': 3, 'ndim': 2}, + {'scale': [[2, 3], [4, 5]], 'batchsize': 3, 'ndim': 2}, + ] + + expectations = [ + torch.tensor([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], + [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]]), + torch.tensor([[[2., 0., 0.], [0., 2., 0.], [0., 0., 1.]], + [[2., 0., 0.], [0., 2., 0.], [0., 0., 1.]]]), + torch.tensor([[[2., 0., 0.], [0., 3., 0.], [0., 0., 1.]], + [[2., 0., 0.], [0., 3., 0.], [0., 0., 1.]], + [[2., 0., 0.], [0., 3., 0.], [0., 0., 1.]]]), + torch.tensor([[[2., 0., 0.], [0., 2., 0.], [0., 0., 1.]], + [[3., 0., 0.], [0., 3., 0.], [0., 0., 1.]], + [[4., 0., 0.], [0., 4., 0.], [0., 0., 1.]]]), + torch.tensor([[[2, 3, 0], [4, 5, 0], [0, 0, 1]], + [[2, 3, 0], [4, 5, 0], [0, 0, 1]], + [[2, 3, 0], [4, 5, 0], [0, 0, 1]]]) + + ] + + for inp, exp in zip(inputs, expectations): + with self.subTest(input=inp, expected=exp): + res = create_scale(**inp).to(exp.dtype) + self.assertTrue(torch.allclose(res, exp, atol=1e-6)) + + with self.assertRaises(ValueError): + create_scale([4, 5, 6, 7], batchsize=3, ndim=2) + + def test_create_translation(self): + inputs = [ + {'offset': None, 'batchsize': 2, 'ndim': 2}, + {'offset': 2, 'batchsize': 2, 'ndim': 2}, + {'offset': [2, 3], 'batchsize': 3, 'ndim': 2}, + {'offset': [2, 3, 4], 'batchsize': 3, 'ndim': 2}, + {'offset': [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], + [[10, 11, 12], [13, 14, 15], [16, 17, 18]], + [[19, 20, 21], [22, 23, 24], [25, 26, 27]]], + 'batchsize': 3, 'ndim': 2}, + {'offset': [[[1, 2, 3], [4, 5, 6]], + [[10, 11, 12], [13, 14, 15]], + [[19, 20, 21], [22, 23, 24]]], + 'batchsize': 3, 'ndim': 2} + + ] + + expectations = [ + torch.tensor([[[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[1, 0, 0], [0, 1, 0], [0, 0, 1]]]), + torch.tensor([[[1, 0, 2], [0, 1, 2], [0, 0, 1]], + [[1, 0, 2], [0, 1, 2], [0, 0, 1]]]), + torch.tensor([[[1, 0, 2], [0, 1, 3], [0, 0, 1]], + [[1, 0, 2], [0, 1, 3], [0, 0, 1]], + [[1, 0, 2], [0, 1, 3], [0, 0, 1]]]), + torch.tensor([[[1, 0, 2], [0, 1, 2], [0, 0, 1]], + [[1, 0, 3], [0, 1, 3], [0, 0, 1]], + [[1, 0, 4], [0, 1, 4], [0, 0, 1]]]), + torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]], + [[10, 11, 12], [13, 14, 15], [16, 17, 18]], + [[19, 20, 21], [22, 23, 24], [25, 26, 27]]]), + 
torch.tensor([[[1, 2, 3], [4, 5, 6], [0, 0, 1]], + [[10, 11, 12], [13, 14, 15], [0, 0, 1]], + [[19, 20, 21], [22, 23, 24], [0, 0, 1]]]) + + ] + + for inp, exp in zip(inputs, expectations): + with self.subTest(input=inp, expected=exp): + res = create_translation(**inp).to(exp.dtype) + self.assertTrue(torch.allclose(res, exp, atol=1e-6)) + + with self.assertRaises(ValueError): + create_translation([4, 5, 6, 7], batchsize=3, ndim=2) + + + def test_format_rotation(self): + inputs = [ + {'rotation': None, 'batchsize': 2, 'ndim': 3}, + {'rotation': 0, 'degree': True, 'batchsize': 2, 'ndim': 2}, + {'rotation': [180, 0, 180], 'degree': True, 'batchsize': 2, 'ndim': 3}, + {'rotation': [180, 0, 180], 'degree': True, 'batchsize': 3, 'ndim': 2}, + {'rotation': [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], + [[10, 11, 12], [13, 14, 15], [16, 17, 18]]], + 'batchsize': 2, 'ndim': 2}, + {'rotation': [[[1, 2, 3], [4, 5, 6]], + [[10, 11, 12], [13, 14, 15]]], + 'batchsize': 2, 'ndim': 2}, + {'rotation': [[1, 2], [3, 4]], 'batchsize': 3, 'ndim': 2, 'degree': False} + + ] + expectations = [ + torch.tensor([[[1., 0., 0., 0.], [0., 1., 0., 0.], + [0., 0., 1., 0.], [0., 0., 0., 1.]], + [[1., 0., 0., 0.], [0., 1., 0., 0.], + [0., 0., 1., 0.], [0., 0., 0., 1.]]]), + torch.tensor([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], + [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]]), + torch.tensor([[[1., 0., 0., 0.], [0., 1., 0., 0.], + [0., 0., 1., 0.], [0., 0., 0., 1.]], + [[1., 0., 0., 0.], [0., 1., 0., 0.], + [0., 0., 1., 0.], [0., 0., 0., 1.]]]), + torch.tensor([[[-1, 0, 0], [0, -1, 0], [0, 0, 1]], + [[-1, 0, 0], [0, -1, 0], [0, 0, 1]], + [[-1, 0, 0], [0, -1, 0], [0, 0, 1]]]), + torch.tensor([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]], + [[10., 11., 12.], [13., 14., 15.], [16., 17., 18.]]]), + torch.tensor([[[1., 2., 3.], [4., 5., 6.], [0., 0., 1.]], + [[10., 11., 12.], [13., 14., 15.], [0., 0., 1.]]]), + torch.tensor([[[1., 2., 0.], [3., 4., 0.], [0., 0., 1.]], + [[1., 2., 0.], [3., 4., 0.], [0., 0., 1.]], + [[1., 2., 0.], [3., 4., 0.], [0., 0., 1.]]]) + ] + + for inp, exp in zip(inputs, expectations): + with self.subTest(input=inp, expected=exp): + res = create_rotation(**inp).to(exp.dtype) + self.assertTrue(torch.allclose(res, exp, atol=1e-6)) + + with self.assertRaises(ValueError): + create_rotation([4, 5, 6, 7], batchsize=1, ndim=2) + + def test_matrix_parametrization(self): + inputs = [ + {'scale': None, 'translation': None, 'rotation': None, 'batchsize': 2, 'ndim': 2, + 'dtype': torch.float}, + {'scale': [[2, 3], [4, 5]], 'translation': [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], + [[10, 11, 12], [13, 14, 15], [16, 17, 18]], + [[19, 20, 21], [22, 23, 24], [25, 26, 27]]], + 'rotation': [180, 0, 180], 'degree': True, 'batchsize': 3, + 'ndim': 2, 'dtype':torch.float} + ] + + expectations = [ + torch.tensor([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], + [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]]), + + torch.bmm(torch.bmm(torch.tensor([[[2., 3., 0], [4., 5., 0.], [0., 0., 1.]], + [[2., 3., 0.], [4., 5., 0.], [0., 0., 1.]], + [[2., 3., 0.], [4., 5., 0.], [0., 0., 1.]]]), + torch.tensor([[[-1., 0., 0.], [0., -1., 0.], [0., 0., 1.]], + [[-1., 0., 0.], [0., -1., 0.], [0., 0., 1.]], + [[-1., 0., 0.], [0., -1., 0.], [0., 0., 1.]]])), + torch.tensor([[[1., 2., 3.], [4., 5., 6.], [0., 0., 1.]], + [[10., 11., 12.], [13., 14., 15.], [0., 0., 1.]], + [[19., 20., 21.], [22., 23., 24.], [0., 0., 1.]]])) + + ] + + for inp, exp in zip(inputs, expectations): + with self.subTest(input=inp, expected=exp): + res = parametrize_matrix(**inp).to(exp.dtype) + 
self.assertTrue(torch.allclose(res, matrix_to_cartesian(exp), atol=1e-6)) + + def test_necessary_assembly(self): + inputs = [ + {'matrix': None, 'translation': [2, 3], 'ndim':2, 'batchsize': 3, + 'dtype': torch.float}, + {'matrix': [[1., 0., 4.], [0., 1., 5.], [0., 0., 1.]], 'translation': [2, 3], 'ndim': 2, 'batchsize': 3, + 'dtype': torch.float}, + {'matrix': [[1., 0., 4.], [0., 1., 5.]], 'translation': [2, 3], 'ndim': 2, 'batchsize': 3, + 'dtype': torch.float} + + ] + expectations = [ + torch.tensor([[[1., 0., 2.], [0., 1., 3.]], + [[1., 0., 2.], [0., 1., 3.]], + [[1., 0., 2.], [0., 1., 3.]]]), + torch.tensor([[[1., 0., 4.], [0., 1., 5.]], + [[1., 0., 4.], [0., 1., 5.]], + [[1., 0., 4.], [0., 1., 5.]]]), + torch.tensor([[[1., 0., 4.], [0., 1., 5.]], + [[1., 0., 4.], [0., 1., 5.]], + [[1., 0., 4.], [0., 1., 5.]]]) + ] + + for inp, exp in zip(inputs, expectations): + with self.subTest(input=inp, expected=exp): + res = assemble_matrix_if_necessary(**inp, degree=False, + device='cpu', scale=None, rotation=None).to(exp.dtype) + self.assertTrue(torch.allclose(res, exp, atol=1e-6)) + + with self.assertRaises(ValueError): + assemble_matrix_if_necessary(matrix=[1, 2, 3, 4, 5], scale=None, + rotation=None, translation=None, + degree=False, dtype=torch.float, + device='cpu', batchsize=1, ndim=2) + if __name__ == '__main__': unittest.main() diff --git a/tests/transforms/test_affine.py b/tests/transforms/test_affine.py index bdb73dd2..a825c620 100644 --- a/tests/transforms/test_affine.py +++ b/tests/transforms/test_affine.py @@ -98,10 +98,6 @@ def test_affine_subtypes(self): self.assertIsInstance(result, torch.Tensor) self.assertTupleEqual(expected_size, result.shape[-2:]) - self.assertTupleEqual((5, 4), trafos[-1](**sample)['data'].shape[2:]) - - self.assertTupleEqual((5, 4), trafos[-1](**sample)['data'].shape[2:]) - if __name__ == '__main__': unittest.main() diff --git a/tests/transforms/test_compose.py b/tests/transforms/test_compose.py index e88af523..196b7ecd 100644 --- a/tests/transforms/test_compose.py +++ b/tests/transforms/test_compose.py @@ -1,3 +1,4 @@ + import unittest import torch @@ -61,12 +62,12 @@ def __call__(self, *args, **kwargs): trafo_a = trafo_a.to(torch.float32) trafo_b = DummyTrafo(torch.tensor([2.], dtype=torch.float32)) trafo_b = trafo_b.to(torch.float32) - self.assertEquals(trafo_a.tmp.dtype, torch.float32) - self.assertEquals(trafo_b.tmp.dtype, torch.float32) + self.assertEqual(trafo_a.tmp.dtype, torch.float32) + self.assertEqual(trafo_b.tmp.dtype, torch.float32) compose = Compose(trafo_a, trafo_b) compose = compose.to(torch.float64) - self.assertEquals(compose.transforms[0].tmp.dtype, torch.float64) + self.assertEqual(compose.transforms[0].tmp.dtype, torch.float64) def test_wrapping_non_module_trafos(self): class DummyTrafo: diff --git a/tests/utils/test_affine.py b/tests/utils/test_affine.py index e74cb870..c6563b72 100644 --- a/tests/utils/test_affine.py +++ b/tests/utils/test_affine.py @@ -1,9 +1,7 @@ import unittest from rising.utils.affine import points_to_homogeneous, matrix_to_homogeneous, \ matrix_to_cartesian, points_to_cartesian, \ - get_batched_eye, _format_scale, _format_translation, deg_to_rad, \ - _format_rotation, parametrize_matrix, assemble_matrix_if_necessary, \ - unit_box + get_batched_eye, deg_to_rad, unit_box import torch import math @@ -142,85 +140,6 @@ def test_batched_eye(self): for _eye in batched_eye: self.assertTrue(torch.allclose(_eye, non_batched_eye, atol=1e-6)) - def test_format_scale(self): - inputs = [ - {'scale': None, 'batchsize': 2, 
'ndim': 2}, - {'scale': 2, 'batchsize': 2, 'ndim': 2}, - {'scale': [2, 3], 'batchsize': 3, 'ndim': 2}, - {'scale': [2, 3, 4], 'batchsize': 3, 'ndim': 2}, - {'scale': [[2, 3], [4, 5]], 'batchsize': 3, 'ndim': 2}, - ] - - expectations = [ - torch.tensor([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], - [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]]), - torch.tensor([[[2., 0., 0.], [0., 2., 0.], [0., 0., 1.]], - [[2., 0., 0.], [0., 2., 0.], [0., 0., 1.]]]), - torch.tensor([[[2., 0., 0.], [0., 3., 0.], [0., 0., 1.]], - [[2., 0., 0.], [0., 3., 0.], [0., 0., 1.]], - [[2., 0., 0.], [0., 3., 0.], [0., 0., 1.]]]), - torch.tensor([[[2., 0., 0.], [0., 2., 0.], [0., 0., 1.]], - [[3., 0., 0.], [0., 3., 0.], [0., 0., 1.]], - [[4., 0., 0.], [0., 4., 0.], [0., 0., 1.]]]), - torch.tensor([[[2, 3, 0], [4, 5, 0], [0, 0, 1]], - [[2, 3, 0], [4, 5, 0], [0, 0, 1]], - [[2, 3, 0], [4, 5, 0], [0, 0, 1]]]) - - ] - - for inp, exp in zip(inputs, expectations): - with self.subTest(input=inp, expected=exp): - res = _format_scale(**inp).to(exp.dtype) - self.assertTrue(torch.allclose(res, exp, atol=1e-6)) - - with self.assertRaises(ValueError): - _format_scale([4, 5, 6, 7], batchsize=3, ndim=2) - - def test_format_translation(self): - inputs = [ - {'offset': None, 'batchsize': 2, 'ndim': 2}, - {'offset': 2, 'batchsize': 2, 'ndim': 2}, - {'offset': [2, 3], 'batchsize': 3, 'ndim': 2}, - {'offset': [2, 3, 4], 'batchsize': 3, 'ndim': 2}, - {'offset': [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], - [[10, 11, 12], [13, 14, 15], [16, 17, 18]], - [[19, 20, 21], [22, 23, 24], [25, 26, 27]]], - 'batchsize': 3, 'ndim': 2}, - {'offset': [[[1, 2, 3], [4, 5, 6]], - [[10, 11, 12], [13, 14, 15]], - [[19, 20, 21], [22, 23, 24]]], - 'batchsize': 3, 'ndim': 2} - - ] - - expectations = [ - torch.tensor([[[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[1, 0, 0], [0, 1, 0], [0, 0, 1]]]), - torch.tensor([[[1, 0, 2], [0, 1, 2], [0, 0, 1]], - [[1, 0, 2], [0, 1, 2], [0, 0, 1]]]), - torch.tensor([[[1, 0, 2], [0, 1, 3], [0, 0, 1]], - [[1, 0, 2], [0, 1, 3], [0, 0, 1]], - [[1, 0, 2], [0, 1, 3], [0, 0, 1]]]), - torch.tensor([[[1, 0, 2], [0, 1, 2], [0, 0, 1]], - [[1, 0, 3], [0, 1, 3], [0, 0, 1]], - [[1, 0, 4], [0, 1, 4], [0, 0, 1]]]), - torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]], - [[10, 11, 12], [13, 14, 15], [16, 17, 18]], - [[19, 20, 21], [22, 23, 24], [25, 26, 27]]]), - torch.tensor([[[1, 2, 3], [4, 5, 6], [0, 0, 1]], - [[10, 11, 12], [13, 14, 15], [0, 0, 1]], - [[19, 20, 21], [22, 23, 24], [0, 0, 1]]]) - - ] - - for inp, exp in zip(inputs, expectations): - with self.subTest(input=inp, expected=exp): - res = _format_translation(**inp).to(exp.dtype) - self.assertTrue(torch.allclose(res, exp, atol=1e-6)) - - with self.assertRaises(ValueError): - _format_translation([4, 5, 6, 7], batchsize=3, ndim=2) - def test_deg_to_rad(self): inputs = [ torch.tensor([tmp * 45. 
for tmp in range(9)]), @@ -234,118 +153,6 @@ def test_deg_to_rad(self): with self.subTest(input=inp, expected=exp): self.assertTrue(torch.allclose(deg_to_rad(inp), exp, atol=1e-6)) - def test_format_rotation(self): - inputs = [ - {'rotation': None, 'batchsize': 2, 'ndim': 3}, - {'rotation': 0, 'degree': True, 'batchsize': 2, 'ndim': 2}, - {'rotation': [180, 0, 180], 'degree': True, 'batchsize': 2, 'ndim': 3}, - {'rotation': [180, 0, 180], 'degree': True, 'batchsize': 3, 'ndim': 2}, - {'rotation': [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], - [[10, 11, 12], [13, 14, 15], [16, 17, 18]]], - 'batchsize': 2, 'ndim': 2}, - {'rotation': [[[1, 2, 3], [4, 5, 6]], - [[10, 11, 12], [13, 14, 15]]], - 'batchsize': 2, 'ndim': 2}, - {'rotation': [[1, 2], [3, 4]], 'batchsize': 3, 'ndim': 2, 'degree': False} - - ] - expectations = [ - torch.tensor([[[1., 0., 0., 0.], [0., 1., 0., 0.], - [0., 0., 1., 0.], [0., 0., 0., 1.]], - [[1., 0., 0., 0.], [0., 1., 0., 0.], - [0., 0., 1., 0.], [0., 0., 0., 1.]]]), - torch.tensor([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], - [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]]), - torch.tensor([[[1., 0., 0., 0.], [0., 1., 0., 0.], - [0., 0., 1., 0.], [0., 0., 0., 1.]], - [[1., 0., 0., 0.], [0., 1., 0., 0.], - [0., 0., 1., 0.], [0., 0., 0., 1.]]]), - torch.tensor([[[-1, 0, 0], [0, -1, 0], [0, 0, 1]], - [[-1, 0, 0], [0, -1, 0], [0, 0, 1]], - [[-1, 0, 0], [0, -1, 0], [0, 0, 1]]]), - torch.tensor([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]], - [[10., 11., 12.], [13., 14., 15.], [16., 17., 18.]]]), - torch.tensor([[[1., 2., 3.], [4., 5., 6.], [0., 0., 1.]], - [[10., 11., 12.], [13., 14., 15.], [0., 0., 1.]]]), - torch.tensor([[[1., 2., 0.], [3., 4., 0.], [0., 0., 1.]], - [[1., 2., 0.], [3., 4., 0.], [0., 0., 1.]], - [[1., 2., 0.], [3., 4., 0.], [0., 0., 1.]]]) - ] - - for inp, exp in zip(inputs, expectations): - with self.subTest(input=inp, expected=exp): - res = _format_rotation(**inp).to(exp.dtype) - self.assertTrue(torch.allclose(res, exp, atol=1e-6)) - - with self.assertRaises(ValueError): - _format_rotation([4, 5, 6, 7], batchsize=1, ndim=2) - - def test_matrix_parametrization(self): - inputs = [ - {'scale': None, 'translation': None, 'rotation': None, 'batchsize': 2, 'ndim': 2, - 'dtype': torch.float}, - {'scale': [[2, 3], [4, 5]], 'translation': [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], - [[10, 11, 12], [13, 14, 15], [16, 17, 18]], - [[19, 20, 21], [22, 23, 24], [25, 26, 27]]], - 'rotation': [180, 0, 180], 'degree': True, 'batchsize': 3, - 'ndim': 2, 'dtype':torch.float} - ] - - expectations = [ - torch.tensor([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], - [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]]), - - torch.bmm(torch.bmm(torch.tensor([[[2., 3., 0], [4., 5., 0.], [0., 0., 1.]], - [[2., 3., 0.], [4., 5., 0.], [0., 0., 1.]], - [[2., 3., 0.], [4., 5., 0.], [0., 0., 1.]]]), - torch.tensor([[[-1., 0., 0.], [0., -1., 0.], [0., 0., 1.]], - [[-1., 0., 0.], [0., -1., 0.], [0., 0., 1.]], - [[-1., 0., 0.], [0., -1., 0.], [0., 0., 1.]]])), - torch.tensor([[[1., 2., 3.], [4., 5., 6.], [0., 0., 1.]], - [[10., 11., 12.], [13., 14., 15.], [0., 0., 1.]], - [[19., 20., 21.], [22., 23., 24.], [0., 0., 1.]]])) - - ] - - for inp, exp in zip(inputs, expectations): - with self.subTest(input=inp, expected=exp): - res = parametrize_matrix(**inp).to(exp.dtype) - self.assertTrue(torch.allclose(res, matrix_to_cartesian(exp), atol=1e-6)) - - def test_necessary_assembly(self): - inputs = [ - {'matrix': None, 'translation': [2, 3], 'ndim':2, 'batchsize': 3, - 'dtype': torch.float}, - {'matrix': [[1., 0., 4.], [0., 1., 
5.], [0., 0., 1.]], 'translation': [2, 3], 'ndim': 2, 'batchsize': 3, - 'dtype': torch.float}, - {'matrix': [[1., 0., 4.], [0., 1., 5.]], 'translation': [2, 3], 'ndim': 2, 'batchsize': 3, - 'dtype': torch.float} - - ] - expectations = [ - torch.tensor([[[1., 0., 2.], [0., 1., 3.]], - [[1., 0., 2.], [0., 1., 3.]], - [[1., 0., 2.], [0., 1., 3.]]]), - torch.tensor([[[1., 0., 4.], [0., 1., 5.]], - [[1., 0., 4.], [0., 1., 5.]], - [[1., 0., 4.], [0., 1., 5.]]]), - torch.tensor([[[1., 0., 4.], [0., 1., 5.]], - [[1., 0., 4.], [0., 1., 5.]], - [[1., 0., 4.], [0., 1., 5.]]]) - ] - - for inp, exp in zip(inputs, expectations): - with self.subTest(input=inp, expected=exp): - res = assemble_matrix_if_necessary(**inp, degree=False, - device='cpu', scale=None, rotation=None).to(exp.dtype) - self.assertTrue(torch.allclose(res, exp, atol=1e-6)) - - with self.assertRaises(ValueError): - assemble_matrix_if_necessary(matrix=[1, 2, 3, 4, 5], scale=None, - rotation=None, translation=None, - degree=False, dtype=torch.float, - device='cpu', batchsize=1, ndim=2) - def test_unit_box_2d(self): curr_img_size = torch.tensor([2, 3]) box = torch.tensor([[0., 0.], [0., curr_img_size[1]], From 2f461eb248d723ef1964c1e08186463164db1583 Mon Sep 17 00:00:00 2001 From: Michael Baumgartner Date: Tue, 18 Feb 2020 17:30:29 +0000 Subject: [PATCH 05/20] autopep8 fix --- tests/transforms/functional/test_affine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/transforms/functional/test_affine.py b/tests/transforms/functional/test_affine.py index 285e7d6d..198e53a0 100644 --- a/tests/transforms/functional/test_affine.py +++ b/tests/transforms/functional/test_affine.py @@ -221,7 +221,6 @@ def test_create_translation(self): with self.assertRaises(ValueError): create_translation([4, 5, 6, 7], batchsize=3, ndim=2) - def test_format_rotation(self): inputs = [ {'rotation': None, 'batchsize': 2, 'ndim': 3}, From 99b56da09b9733bba9b9c616e6db977d1ca0c2cd Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Tue, 18 Feb 2020 19:59:14 +0100 Subject: [PATCH 06/20] start create matrix refactoring --- rising/transforms/functional/affine.py | 247 ++++++++++--------------- 1 file changed, 99 insertions(+), 148 deletions(-) diff --git a/rising/transforms/functional/affine.py b/rising/transforms/functional/affine.py index cc2fd7ca..aab20924 100644 --- a/rising/transforms/functional/affine.py +++ b/rising/transforms/functional/affine.py @@ -1,5 +1,6 @@ import torch import warnings +from torch import Tensor from typing import Union, Sequence from rising.utils.affine import points_to_cartesian, matrix_to_homogeneous, \ @@ -9,13 +10,50 @@ __all__ = [ 'affine_image_transform', - 'affine_point_transform' + 'affine_point_transform', + "create_rotation", + "create_scale", + "create_translation", ] AffineParamType = Union[int, float, Sequence, torch.Tensor] +def expand_affine_param(param: AffineParamType, batchsize: int, ndim: int) -> Tensor: + """ + Bring affine params to shape (batchsize, ndim) + + Parameters + ---------- + param: AffineParamType + affine parameter + batchsize: int + size of batch + ndim: int + number of spatial dimensions + + Returns + ------- + Tensor: + affine params in correct shape + """ + if check_scalar(param): + return torch.tensor([[param] * ndim] * batchsize) + + if not torch.is_tensor(param): + param = torch.tensor(param) + + if not param.ndimension == 2: + if param.shape[0] == ndim: # scalar per dim + param = param.reshape(1, -1) + param = param.expand(batchsize, ndim) + assert all([i == j for i, j in zip(param.shape, 
(batchsize, ndim))]), \ + (f"Affine param need to have shape (batchsize, ndim)" + f"({(batchsize, ndim)}) but found {param.shape}") + return param + + def create_scale(scale: AffineParamType, batchsize: int, ndim: int, device: Union[torch.device, str] = None, @@ -27,14 +65,13 @@ def create_scale(scale: AffineParamType, ---------- scale : torch.Tensor, int, float the scale factor(s). Supported are: - * a full transformation matrix of shape (BATCHSIZE x NDIM x NDIM) * a single parameter (as float or int), which will be replicated for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be + * a parameter per sample, which will be replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all + * a parameter per dimension, which will be replicated for all batch samples + * a parameter per sampler per dimension None will be treated as a scaling factor of 1 batchsize : int the number of samples per batch @@ -50,50 +87,15 @@ def create_scale(scale: AffineParamType, ------- torch.Tensor the homogeneous transformation matrix - """ - if scale is None: scale = 1 - if check_scalar(scale): - - scale = get_batched_eye(batchsize=batchsize, ndim=ndim, device=device, - dtype=dtype) * scale - - elif not torch.is_tensor(scale): - scale = torch.tensor(scale, dtype=dtype, device=device) - - # scale must be tensor by now - scale = scale.to(device=device, dtype=dtype) - - # scale is already batched matrix - if scale.size() == (batchsize, ndim, ndim) or scale.size() == (batchsize, ndim, ndim + 1): - return matrix_to_homogeneous(scale) - - # scale is batched matrix with same element for each dimension or just - # not diagonalized - if scale.size() == (batchsize, ndim) or scale.size() == (batchsize,): - new_scale = get_batched_eye(batchsize=batchsize, ndim=ndim, - device=device, dtype=dtype) - - return matrix_to_homogeneous(new_scale * scale.view(batchsize, -1, 1)) - - # scale contains a non-diagonalized form (will be repeated for each batch - # item) - elif scale.size() == (ndim,): - return matrix_to_homogeneous( - torch.diag(scale).view(1, ndim, ndim).expand(batchsize, - -1, -1).clone()) - - # scale contains a diagonalized but not batched matrix - # (will be repeated for each batch item) - elif scale.size() == (ndim, ndim): - return matrix_to_homogeneous( - scale.view(1, ndim, ndim).expand(batchsize, -1, -1).clone()) - - raise ValueError("Unknown shape for scale matrix: %s" - % str(tuple(scale.size()))) + scale = expand_affine_param(scale, batchsize, ndim).to( + device=device, dtype=dtype) + scale_matrix = torch.bmm( + get_batched_eye(batchsize=batchsize, ndim=ndim, device=device, dtype=dtype), scale) + return matrix_to_homogeneous(scale_matrix) def create_translation(offset: AffineParamType, @@ -109,15 +111,13 @@ def create_translation(offset: AffineParamType, ---------- offset : torch.Tensor, int, float the translation offset(s). 
Supported are: - * a full homogeneous transformation matrix of shape - (BATCHSIZE x NDIM+1 x NDIM+1) * a single parameter (as float or int), which will be replicated for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be + * a parameter per sample, which will be replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all + * a parameter per dimension, which will be replicated for all batch samples + * a parameter per sampler per dimension None will be treated as a translation offset of 0 batchsize : int the number of samples per batch @@ -137,38 +137,11 @@ def create_translation(offset: AffineParamType, """ if offset is None: offset = 0 - - if check_scalar(offset): - offset = torch.tensor([offset] * ndim, device=device, dtype=dtype) - - elif not torch.is_tensor(offset): - offset = torch.tensor(offset, device=device, dtype=dtype) - - # assumes offset to be tensor from now on - offset = offset.to(device=device, dtype=dtype) - - # translation matrix already built - if offset.size() == (batchsize, ndim + 1, ndim + 1): - return offset - elif offset.size() == (batchsize, ndim, ndim + 1): - return matrix_to_homogeneous(offset) - - # not completely built so far -> bring in shape (batchsize, ndim) - if offset.size() == (batchsize,): - offset = offset.view(-1, 1).expand(-1, ndim).clone() - elif offset.size() == (ndim,): - offset = offset.view(1, -1).expand(batchsize, -1).clone() - elif not offset.size() == (batchsize, ndim): - raise ValueError("Unknown shape for offsets: %s" - % str(tuple(offset.shape))) - - # directly build homogeneous form -> use dim+1 - whole_translation_matrix = get_batched_eye(batchsize=batchsize, - ndim=ndim + 1, device=device, - dtype=dtype) - - whole_translation_matrix[:, :-1, -1] = offset.clone() - return whole_translation_matrix + offset = expand_affine_param(offset, batchsize, ndim).to( + device=device, dtype=dtype) + eye = get_batched_eye(batchsize=batchsize, ndim=ndim, device=device, dtype=dtype) + translation_matrix = torch.cat([eye, offset], dim=1) + return matrix_to_homogeneous(translation_matrix) def create_rotation(rotation: AffineParamType, @@ -183,14 +156,13 @@ def create_rotation(rotation: AffineParamType, ---------- rotation : torch.Tensor, int, float the rotation factor(s). 
Supported are: - * a full transformation matrix of shape (BATCHSIZE x NDIM(+1) x NDIM(+1)) * a single parameter (as float or int), which will be replicated for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be + * a parameter per sample, which will be replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all + * a parameter per dimension, which will be replicated for all batch samples + * a parameter per sampler per dimension None will be treated as a rotation factor of 0 batchsize : int the number of samples per batch @@ -214,70 +186,49 @@ def create_rotation(rotation: AffineParamType, """ if rotation is None: rotation = 0 - num_rot_params = 1 if ndim == 2 else ndim - if check_scalar(rotation): - rotation = torch.ones(batchsize, num_rot_params, - device=device, dtype=dtype) * rotation - elif not torch.is_tensor(rotation): - rotation = torch.tensor(rotation, device=device, dtype=dtype) - - # assumes rotation to be tensor by now - rotation = rotation.to(device=device, dtype=dtype) - - # already complete - if rotation.size() == (batchsize, ndim, ndim) or rotation.size() == (batchsize, ndim, ndim + 1): - return matrix_to_homogeneous(rotation) - elif rotation.size() == (batchsize, ndim + 1, ndim + 1): - return rotation - - if degree: - rotation = deg_to_rad(rotation) - - # repeat along batch dimension - if rotation.size() == (ndim, ndim) or rotation.size() == (ndim + 1, ndim + 1): - rotation = rotation[None].expand(batchsize, -1, -1).clone() - if rotation.size(-1) == ndim: - rotation = matrix_to_homogeneous(rotation) - - return rotation - # bring it to default size of (batchsize, num_rot_params) - elif rotation.size() == (batchsize,): - rotation = rotation.view(batchsize, 1).expand(-1, num_rot_params).clone() - elif rotation.size() == (num_rot_params,): - rotation = rotation.view(1, num_rot_params).expand(batchsize, - -1).clone() - elif rotation.size() != (batchsize, num_rot_params): - raise ValueError("Invalid shape for rotation parameters: %s" - % (str(tuple(rotation.size())))) - - sin, cos = rotation.sin(), rotation.cos() - - whole_rot_matrix = get_batched_eye(batchsize=batchsize, ndim=ndim, - device=device, dtype=dtype) - - # assemble the actual matrix - if num_rot_params == 1: - whole_rot_matrix[:, 0, 0] = cos[0].clone() - whole_rot_matrix[:, 1, 1] = cos[0].clone() - whole_rot_matrix[:, 0, 1] = (-sin[0]).clone() - whole_rot_matrix[:, 1, 0] = sin[0].clone() - else: - whole_rot_matrix[:, 0, 0] = (cos[:, 0] * cos[:, 1] * cos[:, 2] - - sin[:, 0] * sin[:, 2]).clone() - whole_rot_matrix[:, 0, 1] = (-cos[:, 0] * cos[:, 1] * sin[:, 2] - - sin[:, 0] * cos[:, 2]).clone() - whole_rot_matrix[:, 0, 2] = (cos[:, 0] * sin[:, 1]).clone() - whole_rot_matrix[:, 1, 0] = (sin[:, 0] * cos[:, 1] * cos[:, 2] - + cos[:, 0] * sin[:, 2]).clone() - whole_rot_matrix[:, 1, 1] = (-sin[:, 0] * cos[:, 1] * sin[:, 2] - + cos[:, 0] * cos[:, 2]).clone() - whole_rot_matrix[:, 2, 0] = (-sin[:, 1] * cos[:, 2]).clone() - whole_rot_matrix[:, 2, 1] = (-sin[:, 1] * sin[:, 2]).clone() - whole_rot_matrix[:, 2, 2] = (cos[:, 1]).clone() - - return matrix_to_homogeneous(whole_rot_matrix) + rotation = expand_affine_param(rotation, batchsize, num_rot_params) + matrix_fn = create_rotation_2d if ndim == 2 else create_rotation_3d_zyx + rotation_matrix = torch.stack([matrix_fn(r) for r in rotation]).to( + device=device, dtype=dtype) + return matrix_to_homogeneous(rotation_matrix) + + +def 
create_rotation_2d(sin: Tensor, cos: Tensor) -> Tensor: + return torch.tensor([[cos.clone(), -sin.clone()], [sin.clone(), cos.clone()]]) + + +def create_rotation_3d_xzy(sin: Tensor, cos: Tensor) -> Tensor: + rot_0 = create_rotation_3d_x(sin[0], cos[0]) + rot_1 = create_rotation_3d_y(sin[1], cos[1]) + rot_2 = create_rotation_3d_z(sin[2], cos[2]) + return (rot_0 @ rot_1) @ rot_2 + + +def create_rotation_3d_zyx(sin: Tensor, cos: Tensor) -> Tensor: + rot_0 = create_rotation_3d_x(sin[0], cos[0]) + rot_1 = create_rotation_3d_y(sin[1], cos[1]) + rot_2 = create_rotation_3d_z(sin[2], cos[2]) + return (rot_2 @ rot_1) @ rot_0 + + +def create_rotation_3d_x(sin: Tensor, cos: Tensor) -> Tensor: + return torch.tensor([[1., 0., 0.], + [0., cos.clone(), -sin.clone()], + [0., sin.clone(), cos.clone()]]) + + +def create_rotation_3d_y(sin: Tensor, cos: Tensor) -> Tensor: + return torch.tensor([[cos.clone(), 0., sin.clone()], + [0., 1., 0.], + [-sin.clone(), 0., cos.clone()]]) + + +def create_rotation_3d_z(sin: Tensor, cos: Tensor) -> Tensor: + return torch.tensor([[cos.clone(), -sin.clone(), 0.] + [sin.clone(), cos.clone(), 0.], + [1., 0., 0.]]) def parametrize_matrix(scale: AffineParamType, From e5672cf6c19ea0f06d90732bab2a5c79b754737d Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Sat, 22 Feb 2020 21:50:04 +0100 Subject: [PATCH 07/20] refactor affine param handling --- rising/transforms/affine.py | 57 +++--- rising/transforms/functional/affine.py | 217 +++++++++++++-------- rising/utils/affine.py | 9 +- tests/transforms/functional/test_affine.py | 153 +++++++-------- 4 files changed, 240 insertions(+), 196 deletions(-) diff --git a/rising/transforms/affine.py b/rising/transforms/affine.py index 666526eb..b9e32ce0 100644 --- a/rising/transforms/affine.py +++ b/rising/transforms/affine.py @@ -39,45 +39,39 @@ def __init__(self, scale: AffineParamType = None, ---------- scale : torch.Tensor, int, float, optional the scale factor(s). Supported are: - * a full transformation matrix of shape - (BATCHSIZE x NDIM x NDIM) - * a single parameter (as float or int), which will be - replicated for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will - be replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for - all batch samples + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a parameter per sample, which will be + replicated for all dimensions + * a parameter per dimension, which will be replicated for all + batch samples + * a parameter per sample per dimension None will be treated as a scaling factor of 1 rotation : torch.Tensor, int, float, optional the rotation factor(s). 
Supported are: - * a full transformation matrix of shape - (BATCHSIZE x NDIM x NDIM) - * a single parameter (as float or int), which will be - replicated for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will - be replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for - all batch samples + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a parameter per sample, which will be + replicated for all dimensions + * a parameter per dimension, which will be replicated for all + batch samples + * a parameter per sample per dimension None will be treated as a rotation factor of 1 translation : torch.Tensor, int, float the translation offset(s). Supported are: - * a full homogeneous transformation matrix of shape - (BATCHSIZE x NDIM+1 x NDIM+1) - * a single parameter (as float or int), which will be - replicated for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will - be replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for - all batch samples + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a parameter per sample, which will be + replicated for all dimensions + * a parameter per dimension, which will be replicated for all + batch samples + * a parameter per sample per dimension None will be treated as a translation offset of 0 matrix : torch.Tensor, optional if given, overwrites the parameters for :param:`scale`, :param:rotation` and :param:`translation`. - Should be a matrix o shape (BATCHSIZE,) NDIM, NDIM+1. 
- This matrix represents the whole homogeneous transformation matrix + Should be a matrix of shape [(BATCHSIZE,) NDIM, NDIM(+1)] + This matrix represents the whole transformation matrix keys: Sequence keys which should be augmented grad: bool @@ -144,7 +138,6 @@ def assemble_matrix(self, **data) -> torch.Tensor: the (batched) transformation matrix """ - batchsize = data[self.keys[0]].shape[0] ndim = len(data[self.keys[0]].shape) - 2 # channel and batch dim device = data[self.keys[0]].device @@ -487,8 +480,8 @@ def __init__(self, class StackedAffine(Affine): def __init__( self, - *transforms: Union[Affine, Sequence[Union[Sequence[Affine], - Affine]]], + *transforms: Union[Affine, Sequence[ + Union[Sequence[Affine], Affine]]], keys: Sequence = ('data',), grad: bool = False, output_size: tuple = None, diff --git a/rising/transforms/functional/affine.py b/rising/transforms/functional/affine.py index aab20924..002d1d52 100644 --- a/rising/transforms/functional/affine.py +++ b/rising/transforms/functional/affine.py @@ -20,7 +20,7 @@ AffineParamType = Union[int, float, Sequence, torch.Tensor] -def expand_affine_param(param: AffineParamType, batchsize: int, ndim: int) -> Tensor: +def expand_scalar_param(param: AffineParamType, batchsize: int, ndim: int) -> Tensor: """ Bring affine params to shape (batchsize, ndim) @@ -39,19 +39,25 @@ def expand_affine_param(param: AffineParamType, batchsize: int, ndim: int) -> Te affine params in correct shape """ if check_scalar(param): - return torch.tensor([[param] * ndim] * batchsize) + return torch.tensor([[param] * ndim] * batchsize).float() if not torch.is_tensor(param): param = torch.tensor(param) + else: + param = param.clone() - if not param.ndimension == 2: + if not param.ndimension() == 2: if param.shape[0] == ndim: # scalar per dim - param = param.reshape(1, -1) - param = param.expand(batchsize, ndim) + param = param.reshape(1, -1).expand(batchsize, ndim) + elif param.shape[0] == batchsize: # scalar per batch + param = param.reshape(-1, 1).expand(batchsize, ndim) + else: + raise ValueError("Unknown param for expanding. 
" + f"Found {param} for batchsize {batchsize} and ndim {ndim}") assert all([i == j for i, j in zip(param.shape, (batchsize, ndim))]), \ (f"Affine param need to have shape (batchsize, ndim)" f"({(batchsize, ndim)}) but found {param.shape}") - return param + return param.float() def create_scale(scale: AffineParamType, @@ -91,10 +97,11 @@ def create_scale(scale: AffineParamType, if scale is None: scale = 1 - scale = expand_affine_param(scale, batchsize, ndim).to( + scale = expand_scalar_param(scale, batchsize, ndim).to( device=device, dtype=dtype) - scale_matrix = torch.bmm( - get_batched_eye(batchsize=batchsize, ndim=ndim, device=device, dtype=dtype), scale) + scale_matrix = torch.stack( + [eye * s for eye, s in zip(get_batched_eye( + batchsize=batchsize, ndim=ndim, device=device, dtype=dtype), scale)]) return matrix_to_homogeneous(scale_matrix) @@ -137,10 +144,11 @@ def create_translation(offset: AffineParamType, """ if offset is None: offset = 0 - offset = expand_affine_param(offset, batchsize, ndim).to( + offset = expand_scalar_param(offset, batchsize, ndim).to( device=device, dtype=dtype) - eye = get_batched_eye(batchsize=batchsize, ndim=ndim, device=device, dtype=dtype) - translation_matrix = torch.cat([eye, offset], dim=1) + eye_batch = get_batched_eye(batchsize=batchsize, ndim=ndim, device=device, dtype=dtype) + translation_matrix = torch.stack([torch.cat([eye, o.view(-1, 1)], dim=1) + for eye, o in zip(eye_batch, offset)]) return matrix_to_homogeneous(translation_matrix) @@ -188,47 +196,120 @@ def create_rotation(rotation: AffineParamType, rotation = 0 num_rot_params = 1 if ndim == 2 else ndim - rotation = expand_affine_param(rotation, batchsize, num_rot_params) - matrix_fn = create_rotation_2d if ndim == 2 else create_rotation_3d_zyx - rotation_matrix = torch.stack([matrix_fn(r) for r in rotation]).to( + rotation = expand_scalar_param(rotation, batchsize, num_rot_params) + if degree: + rotation = deg_to_rad(rotation) + + matrix_fn = create_rotation_2d if ndim == 2 else create_rotation_3d + sin, cos = torch.sin(rotation), torch.cos(rotation) + rotation_matrix = torch.stack([matrix_fn(s, c) for s, c in zip(sin, cos)]).to( device=device, dtype=dtype) return matrix_to_homogeneous(rotation_matrix) def create_rotation_2d(sin: Tensor, cos: Tensor) -> Tensor: + """ + Create a 2d rotation matrix + + Parameters + ---------- + sin: Tensor + sin value to use for rotation matrix, [1] + cos: Tensor + cos value to use for rotation matrix, [1] + + Returns + ------- + Tensor + rotation matrix, [2, 2] + """ return torch.tensor([[cos.clone(), -sin.clone()], [sin.clone(), cos.clone()]]) -def create_rotation_3d_xzy(sin: Tensor, cos: Tensor) -> Tensor: - rot_0 = create_rotation_3d_x(sin[0], cos[0]) - rot_1 = create_rotation_3d_y(sin[1], cos[1]) - rot_2 = create_rotation_3d_z(sin[2], cos[2]) - return (rot_0 @ rot_1) @ rot_2 +def create_rotation_3d(sin: Tensor, cos: Tensor) -> Tensor: + """ + Create a 3d rotation matrix which sequentially applies the rotation + around axis (rot axis 0 -> rot axis 1 -> rot axis 2) + + Parameters + ---------- + sin: Tensor + sin values to use for the rotation, (axis 0, axis 1, axis 2)[3] + cos: Tensor + cos values to use for the rotation, (axis 0, axis 1, axis 2)[3] + + Returns + ------- + Tensor + rotation matrix, [3, 3] + """ + rot_0 = create_rotation_3d_0(sin[0], cos[0]) + rot_1 = create_rotation_3d_1(sin[1], cos[1]) + rot_2 = create_rotation_3d_2(sin[2], cos[2]) + return rot_2 @ (rot_1 @ rot_0) -def create_rotation_3d_zyx(sin: Tensor, cos: Tensor) -> Tensor: - rot_0 = 
create_rotation_3d_x(sin[0], cos[0]) - rot_1 = create_rotation_3d_y(sin[1], cos[1]) - rot_2 = create_rotation_3d_z(sin[2], cos[2]) - return (rot_2 @ rot_1) @ rot_0 +def create_rotation_3d_0(sin: Tensor, cos: Tensor) -> Tensor: + """ + Create a rotation matrix around the zero-th axis + Parameters + ---------- + sin: Tensor + sin value to use for rotation matrix, [1] + cos: Tensor + cos value to use for rotation matrix, [1] -def create_rotation_3d_x(sin: Tensor, cos: Tensor) -> Tensor: + Returns + ------- + Tensor: + rotation matrix, [3, 3] + """ return torch.tensor([[1., 0., 0.], [0., cos.clone(), -sin.clone()], [0., sin.clone(), cos.clone()]]) -def create_rotation_3d_y(sin: Tensor, cos: Tensor) -> Tensor: +def create_rotation_3d_1(sin: Tensor, cos: Tensor) -> Tensor: + """ + Create a rotation matrix around the first axis + + Parameters + ---------- + sin: Tensor + sin value to use for rotation matrix, [1] + cos: Tensor + cos value to use for rotation matrix, [1] + + Returns + ------- + Tensor: + rotation matrix, [3, 3] + """ return torch.tensor([[cos.clone(), 0., sin.clone()], [0., 1., 0.], [-sin.clone(), 0., cos.clone()]]) -def create_rotation_3d_z(sin: Tensor, cos: Tensor) -> Tensor: - return torch.tensor([[cos.clone(), -sin.clone(), 0.] +def create_rotation_3d_2(sin: Tensor, cos: Tensor) -> Tensor: + """ + Create a rotation matrix around the second axis + + Parameters + ---------- + sin: Tensor + sin value to use for rotation matrix, [1] + cos: Tensor + cos value to use for rotation matrix, [1] + + Returns + ------- + Tensor: + rotation matrix, [3, 3] + """ + return torch.tensor([[cos.clone(), -sin.clone(), 0.], [sin.clone(), cos.clone(), 0.], - [1., 0., 0.]]) + [0., 0., 1.]]) def parametrize_matrix(scale: AffineParamType, @@ -245,37 +326,33 @@ def parametrize_matrix(scale: AffineParamType, ---------- scale : torch.Tensor, int, float the scale factor(s). Supported are: - * a full transformation matrix of shape (BATCHSIZE x NDIM x NDIM) * a single parameter (as float or int), which will be replicated for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be + * a parameter per sample, which will be replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all + * a parameter per dimension, which will be replicated for all batch samples + * a parameter per sampler per dimension None will be treated as a scaling factor of 1 rotation : torch.Tensor, int, float the rotation factor(s). Supported are: - * a full transformation matrix of shape (BATCHSIZE x NDIM x NDIM) * a single parameter (as float or int), which will be replicated for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be + * a parameter per sample, which will be replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all + * a parameter per dimension, which will be replicated for all batch samples + * a parameter per sampler per dimension None will be treated as a rotation factor of 1 translation : torch.Tensor, int, float the translation offset(s). 
Supported are: - * a full homogeneous transformation matrix of shape - (BATCHSIZE x NDIM+1 x NDIM+1) * a single parameter (as float or int), which will be replicated for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be + * a parameter per sample, which will be replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all + * a parameter per dimension, which will be replicated for all batch samples + * a parameter per sampler per dimension None will be treated as a translation offset of 0 batchsize : int the number of samples per batch @@ -294,8 +371,7 @@ def parametrize_matrix(scale: AffineParamType, Returns ------- torch.Tensor - the transformation matrix (of shape (BATCHSIZE x NDIM x NDIM+1) - + the transformation matrix [BATCHSIZE, NDIM, NDIM+1] """ scale = create_scale(scale, batchsize=batchsize, ndim=ndim, device=device, dtype=dtype) @@ -304,7 +380,6 @@ def parametrize_matrix(scale: AffineParamType, translation = create_translation(translation, batchsize=batchsize, ndim=ndim, device=device, dtype=dtype) - return torch.bmm(torch.bmm(scale, rotation), translation)[:, :-1] @@ -328,37 +403,33 @@ def assemble_matrix_if_necessary(batchsize: int, ndim: int, the image dimensionality scale : torch.Tensor, int, float the scale factor(s). Supported are: - * a full transformation matrix of shape (BATCHSIZE x NDIM x NDIM) * a single parameter (as float or int), which will be replicated for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be + * a parameter per sample, which will be replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all + * a parameter per dimension, which will be replicated for all batch samples + * a parameter per sampler per dimension None will be treated as a scaling factor of 1 rotation : torch.Tensor, int, float the rotation factor(s). Supported are: - * a full transformation matrix of shape (BATCHSIZE x NDIM x NDIM) * a single parameter (as float or int), which will be replicated for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be + * a parameter per sample, which will be replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all + * a parameter per dimension, which will be replicated for all batch samples + * a parameter per sampler per dimension None will be treated as a rotation factor of 1 translation : torch.Tensor, int, float the translation offset(s). Supported are: - * a full homogeneous transformation matrix of shape - (BATCHSIZE x NDIM+1 x NDIM+1) * a single parameter (as float or int), which will be replicated for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be + * a parameter per sample, which will be replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all + * a parameter per dimension, which will be replicated for all batch samples + * a parameter per sampler per dimension None will be treated as a translation offset of 0 matrix : torch.Tensor the transformation matrix. 
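To make the composition concrete, here is a pure-torch sketch (batch size 1, 2D, hand-picked values) of how the three homogeneous matrices are chained before the homogeneous row is dropped, mirroring the torch.bmm(...)[:, :-1] expression above:

import torch

# homogeneous 2d matrices for a single sample (batch size 1)
scale = torch.tensor([[[2., 0., 0.],
                       [0., 2., 0.],
                       [0., 0., 1.]]])          # isotropic scaling by 2
rotation = torch.tensor([[[0., -1., 0.],
                          [1.,  0., 0.],
                          [0.,  0., 1.]]])      # 90 degree rotation
translation = torch.tensor([[[1., 0., 5.],
                             [0., 1., 7.],
                             [0., 0., 1.]]])    # offset of (5, 7)

# same composition order as parametrize_matrix:
# scale @ rotation @ translation, then drop the last (homogeneous) row
affine = torch.bmm(torch.bmm(scale, rotation), translation)[:, :-1]
print(affine.shape)   # torch.Size([1, 2, 3]) -> [N, NDIM, NDIM + 1]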
If other than None: overwrites separate @@ -379,24 +450,18 @@ def assemble_matrix_if_necessary(batchsize: int, ndim: int, """ if matrix is None: - matrix = parametrize_matrix(scale=scale, rotation=rotation, - translation=translation, - batchsize=batchsize, - ndim=ndim, - degree=degree, - device=device, - dtype=dtype) - + matrix = parametrize_matrix( + scale=scale, rotation=rotation, translation=translation, + batchsize=batchsize, ndim=ndim, degree=degree, + device=device, dtype=dtype) else: if not torch.is_tensor(matrix): matrix = torch.tensor(matrix) - matrix = matrix.to(dtype=dtype, device=device) # batch dimension missing -> Replicate for each sample in batch if len(matrix.shape) == 2: matrix = matrix[None].expand(batchsize, -1, -1).clone() - if matrix.shape == (batchsize, ndim, ndim + 1): return matrix elif matrix.shape == (batchsize, ndim + 1, ndim + 1): @@ -425,7 +490,6 @@ def affine_point_transform(point_batch: torch.Tensor, ------- torch.Tensor the batch of transformed points in cartesian coordinates) - """ point_batch = points_to_homogeneous(point_batch) matrix_batch = matrix_to_homogeneous(matrix_batch) @@ -447,9 +511,9 @@ def affine_image_transform(image_batch: torch.Tensor, Parameters ---------- image_batch : torch.Tensor - the batch to transform. Should have shape of N x C x (D x) H x W + the batch to transform. Should have shape of [N, C, dims] matrix_batch : torch.Tensor - a batch of affine matrices with shape N x NDIM-1 x NDIM + a batch of affine matrices with shape [N, NDIM, NDIM+1] output_size : Iterable if given, this will be the resulting image size. Defaults to ``None`` adjust_size : bool @@ -486,11 +550,10 @@ def affine_image_transform(image_batch: torch.Tensor, If None of them is set, the resulting image will have the same size as the input image """ - # add batch dimension if necessary if len(matrix_batch.shape) < 3: - matrix_batch = matrix_batch[None, ...].expand(image_batch.size(0), - -1, -1).clone() + matrix_batch = matrix_batch[None, ...].expand( + image_batch.size(0), -1, -1).clone() image_size = image_batch.shape[2:] @@ -501,9 +564,7 @@ def affine_image_transform(image_batch: torch.Tensor, if adjust_size: warnings.warn("Adjust size is mutually exclusive with a " "given output size.", UserWarning) - new_size = output_size - elif adjust_size: new_size = tuple([int(tmp.item()) for tmp in _check_new_img_size(image_size, @@ -539,7 +600,7 @@ def _check_new_img_size(curr_img_size, matrix: torch.Tensor, the size of the current image. 
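The adjust_size logic relies on _check_new_img_size, which essentially pushes the image corner points through the affine matrix and measures the extent of the result. A rough standalone sketch of that idea with example numbers (not the exact implementation):

import torch

# corners of a 2d image of size (H, W) = (25, 30), in homogeneous coordinates
h, w = 25., 30.
corners = torch.tensor([[0., 0., 1.],
                        [0., w,  1.],
                        [h,  0., 1.],
                        [h,  w,  1.]])

theta = torch.tensor([[0., -1., 0.],
                      [1.,  0., 0.]])   # 90 degree rotation, shape [NDIM, NDIM + 1]

new_corners = corners @ theta.t()       # [4, 2] transformed corner points
new_size = new_corners.max(dim=0)[0] - new_corners.min(dim=0)[0]
print(new_size)                         # tensor([30., 25.]) -> H and W are swapped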
If int, it will be used as size for all image dimensions matrix : torch.Tensor - a batch of affine matrices with shape N x NDIM x NDIM + 1 + a batch of affine matrices with shape [N, NDIM, NDIM+1] zero_border : bool whether or not to have a fixed image border at zero diff --git a/rising/utils/affine.py b/rising/utils/affine.py index ab61c928..3b916fe4 100644 --- a/rising/utils/affine.py +++ b/rising/utils/affine.py @@ -34,7 +34,7 @@ def matrix_to_homogeneous(batch: torch.Tensor) -> torch.Tensor: Parameters ---------- batch : torch.Tensor - the batch of matrices to convert + the batch of matrices to convert [N, dim, dim] Returns ------- @@ -46,10 +46,9 @@ def matrix_to_homogeneous(batch: torch.Tensor) -> torch.Tensor: missing = batch.new_zeros(size=(*batch.shape[:-1], 1)) batch = torch.cat([batch, missing], dim=-1) - missing = torch.zeros((batch.size(0), - *[1 for tmp in batch.shape[1:-1]], - batch.size(-1)), - device=batch.device, dtype=batch.dtype) + missing = torch.zeros( + (batch.size(0), *[1 for tmp in batch.shape[1:-1]], batch.size(-1)), + device=batch.device, dtype=batch.dtype) missing[..., -1] = 1 diff --git a/tests/transforms/functional/test_affine.py b/tests/transforms/functional/test_affine.py index 198e53a0..2b947cf2 100644 --- a/tests/transforms/functional/test_affine.py +++ b/tests/transforms/functional/test_affine.py @@ -8,21 +8,19 @@ class AffineTestCase(unittest.TestCase): def test_check_image_size(self): - images = [torch.rand(11, 2, 3, 4, 5), torch.rand(11, 2, 3, 4), torch.rand(11, 2, 3, 3)] - - img_sizes = [ - [3, 4, 5], [3, 4], 3 + images = [ + torch.rand(11, 2, 3, 4, 5), + torch.rand(11, 2, 3, 4), + torch.rand(11, 2, 3, 3), ] + img_sizes = [[3, 4, 5], [3, 4], 3] + scales = [ - torch.tensor([[2., 0., 0.], - [0., 3., 0.], - [0., 0., 4.]]), - torch.tensor([[2., 0.], [0., 3.]]), - torch.tensor([[2., 0.], [0., 3.]]) + torch.tensor([2., 3., 4.]), + torch.tensor([2., 3.]), + torch.tensor([2., 3.]) ] - rots = [[45., 90., 135.], [45.], [45.]] - trans = [[0., 10., 20.], [10., 20.], [10., 20.]] edges = [ @@ -38,9 +36,8 @@ def test_check_image_size(self): ] ] - for img, size, scale, rot, tran, edge_pts in zip(images, img_sizes, - scales, rots, trans, - edges): + for img, size, scale, rot, tran, edge_pts in zip( + images, img_sizes, scales, rots, trans, edges): ndim = scale.size(-1) with self.subTest(ndim=ndim): affine = matrix_to_homogeneous( @@ -53,26 +50,21 @@ def test_check_image_size(self): new_edges = torch.bmm(edge_pts.unsqueeze(0), affine.clone().permute(0, 2, 1)) img_size_zero_border = new_edges.max(dim=1)[0][0] - img_size_non_zero_border = (new_edges.max(dim=1)[0] - - new_edges.min(dim=1)[0])[0] + img_size_non_zero_border = (new_edges.max(dim=1)[0] - new_edges.min(dim=1)[0])[0] fn_result_zero_border = _check_new_img_size( - size, - matrix_to_cartesian( - affine.expand(img.size(0), -1, -1).clone()), + size, matrix_to_cartesian(affine.expand(img.size(0), -1, -1).clone()), zero_border=True, ) fn_result_non_zero_border = _check_new_img_size( - size, - matrix_to_cartesian( - affine.expand(img.size(0), -1, -1).clone()), + size, matrix_to_cartesian(affine.expand(img.size(0), -1, -1).clone()), zero_border=False, ) - self.assertTrue(torch.allclose(img_size_zero_border[:-1], - fn_result_zero_border)) - self.assertTrue(torch.allclose(img_size_non_zero_border[:-1], - fn_result_non_zero_border)) + self.assertTrue(torch.allclose( + img_size_zero_border[:-1], fn_result_zero_border)) + self.assertTrue(torch.allclose( + img_size_non_zero_border[:-1], fn_result_non_zero_border)) def 
test_affine_point_transform(self): points = [ @@ -111,7 +103,6 @@ def test_affine_point_transform(self): atol=1e-7)) def test_affine_image_trafo(self): - matrix = torch.tensor([[4., 0., 0.], [0., 5., 0.]]) image_batch = torch.zeros(10, 3, 25, 25, dtype=torch.float, device='cpu') @@ -148,7 +139,7 @@ def test_create_scale(self): {'scale': 2, 'batchsize': 2, 'ndim': 2}, {'scale': [2, 3], 'batchsize': 3, 'ndim': 2}, {'scale': [2, 3, 4], 'batchsize': 3, 'ndim': 2}, - {'scale': [[2, 3], [4, 5]], 'batchsize': 3, 'ndim': 2}, + # {'scale': [[2, 3], [4, 5]], 'batchsize': 3, 'ndim': 2}, ] expectations = [ @@ -162,10 +153,9 @@ def test_create_scale(self): torch.tensor([[[2., 0., 0.], [0., 2., 0.], [0., 0., 1.]], [[3., 0., 0.], [0., 3., 0.], [0., 0., 1.]], [[4., 0., 0.], [0., 4., 0.], [0., 0., 1.]]]), - torch.tensor([[[2, 3, 0], [4, 5, 0], [0, 0, 1]], - [[2, 3, 0], [4, 5, 0], [0, 0, 1]], - [[2, 3, 0], [4, 5, 0], [0, 0, 1]]]) - + # torch.tensor([[[2, 3, 0], [4, 5, 0], [0, 0, 1]], + # [[2, 3, 0], [4, 5, 0], [0, 0, 1]], + # [[2, 3, 0], [4, 5, 0], [0, 0, 1]]]) ] for inp, exp in zip(inputs, expectations): @@ -182,14 +172,14 @@ def test_create_translation(self): {'offset': 2, 'batchsize': 2, 'ndim': 2}, {'offset': [2, 3], 'batchsize': 3, 'ndim': 2}, {'offset': [2, 3, 4], 'batchsize': 3, 'ndim': 2}, - {'offset': [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], - [[10, 11, 12], [13, 14, 15], [16, 17, 18]], - [[19, 20, 21], [22, 23, 24], [25, 26, 27]]], - 'batchsize': 3, 'ndim': 2}, - {'offset': [[[1, 2, 3], [4, 5, 6]], - [[10, 11, 12], [13, 14, 15]], - [[19, 20, 21], [22, 23, 24]]], - 'batchsize': 3, 'ndim': 2} + # {'offset': [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], + # [[10, 11, 12], [13, 14, 15], [16, 17, 18]], + # [[19, 20, 21], [22, 23, 24], [25, 26, 27]]], + # 'batchsize': 3, 'ndim': 2}, + # {'offset': [[[1, 2, 3], [4, 5, 6]], + # [[10, 11, 12], [13, 14, 15]], + # [[19, 20, 21], [22, 23, 24]]], + # 'batchsize': 3, 'ndim': 2} ] @@ -204,12 +194,12 @@ def test_create_translation(self): torch.tensor([[[1, 0, 2], [0, 1, 2], [0, 0, 1]], [[1, 0, 3], [0, 1, 3], [0, 0, 1]], [[1, 0, 4], [0, 1, 4], [0, 0, 1]]]), - torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]], - [[10, 11, 12], [13, 14, 15], [16, 17, 18]], - [[19, 20, 21], [22, 23, 24], [25, 26, 27]]]), - torch.tensor([[[1, 2, 3], [4, 5, 6], [0, 0, 1]], - [[10, 11, 12], [13, 14, 15], [0, 0, 1]], - [[19, 20, 21], [22, 23, 24], [0, 0, 1]]]) + # torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]], + # [[10, 11, 12], [13, 14, 15], [16, 17, 18]], + # [[19, 20, 21], [22, 23, 24], [25, 26, 27]]]), + # torch.tensor([[[1, 2, 3], [4, 5, 6], [0, 0, 1]], + # [[10, 11, 12], [13, 14, 15], [0, 0, 1]], + # [[19, 20, 21], [22, 23, 24], [0, 0, 1]]]) ] @@ -225,15 +215,18 @@ def test_format_rotation(self): inputs = [ {'rotation': None, 'batchsize': 2, 'ndim': 3}, {'rotation': 0, 'degree': True, 'batchsize': 2, 'ndim': 2}, - {'rotation': [180, 0, 180], 'degree': True, 'batchsize': 2, 'ndim': 3}, - {'rotation': [180, 0, 180], 'degree': True, 'batchsize': 3, 'ndim': 2}, - {'rotation': [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], - [[10, 11, 12], [13, 14, 15], [16, 17, 18]]], - 'batchsize': 2, 'ndim': 2}, - {'rotation': [[[1, 2, 3], [4, 5, 6]], - [[10, 11, 12], [13, 14, 15]]], - 'batchsize': 2, 'ndim': 2}, - {'rotation': [[1, 2], [3, 4]], 'batchsize': 3, 'ndim': 2, 'degree': False} + + # TODO: update tests with multiple rotation + # {'rotation': [180, 0, 180], 'degree': True, 'batchsize': 2, 'ndim': 3}, + # {'rotation': [180, 0, 180], 'degree': True, 'batchsize': 3, 'ndim': 2}, + + # {'rotation': [[[1, 2, 3], [4, 5, 
6], [7, 8, 9]], + # [[10, 11, 12], [13, 14, 15], [16, 17, 18]]], + # 'batchsize': 2, 'ndim': 2}, + # {'rotation': [[[1, 2, 3], [4, 5, 6]], + # [[10, 11, 12], [13, 14, 15]]], + # 'batchsize': 2, 'ndim': 2}, + # {'rotation': [[1, 2], [3, 4]], 'batchsize': 3, 'ndim': 2, 'degree': False} ] expectations = [ @@ -243,20 +236,21 @@ def test_format_rotation(self): [0., 0., 1., 0.], [0., 0., 0., 1.]]]), torch.tensor([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]]), - torch.tensor([[[1., 0., 0., 0.], [0., 1., 0., 0.], - [0., 0., 1., 0.], [0., 0., 0., 1.]], - [[1., 0., 0., 0.], [0., 1., 0., 0.], - [0., 0., 1., 0.], [0., 0., 0., 1.]]]), - torch.tensor([[[-1, 0, 0], [0, -1, 0], [0, 0, 1]], - [[-1, 0, 0], [0, -1, 0], [0, 0, 1]], - [[-1, 0, 0], [0, -1, 0], [0, 0, 1]]]), - torch.tensor([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]], - [[10., 11., 12.], [13., 14., 15.], [16., 17., 18.]]]), - torch.tensor([[[1., 2., 3.], [4., 5., 6.], [0., 0., 1.]], - [[10., 11., 12.], [13., 14., 15.], [0., 0., 1.]]]), - torch.tensor([[[1., 2., 0.], [3., 4., 0.], [0., 0., 1.]], - [[1., 2., 0.], [3., 4., 0.], [0., 0., 1.]], - [[1., 2., 0.], [3., 4., 0.], [0., 0., 1.]]]) + # torch.tensor([[[1., 0., 0., 0.], [0., 1., 0., 0.], + # [0., 0., 1., 0.], [0., 0., 0., 1.]], + # [[1., 0., 0., 0.], [0., 1., 0., 0.], + # [0., 0., 1., 0.], [0., 0., 0., 1.]]]), + # torch.tensor([[[-1, 0, 0], [0, -1, 0], [0, 0, 1]], + # [[-1, 0, 0], [0, -1, 0], [0, 0, 1]], + # [[-1, 0, 0], [0, -1, 0], [0, 0, 1]]]), + + # torch.tensor([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]], + # [[10., 11., 12.], [13., 14., 15.], [16., 17., 18.]]]), + # torch.tensor([[[1., 2., 3.], [4., 5., 6.], [0., 0., 1.]], + # [[10., 11., 12.], [13., 14., 15.], [0., 0., 1.]]]), + # torch.tensor([[[1., 2., 0.], [3., 4., 0.], [0., 0., 1.]], + # [[1., 2., 0.], [3., 4., 0.], [0., 0., 1.]], + # [[1., 2., 0.], [3., 4., 0.], [0., 0., 1.]]]) ] for inp, exp in zip(inputs, expectations): @@ -271,9 +265,7 @@ def test_matrix_parametrization(self): inputs = [ {'scale': None, 'translation': None, 'rotation': None, 'batchsize': 2, 'ndim': 2, 'dtype': torch.float}, - {'scale': [[2, 3], [4, 5]], 'translation': [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], - [[10, 11, 12], [13, 14, 15], [16, 17, 18]], - [[19, 20, 21], [22, 23, 24], [25, 26, 27]]], + {'scale': [2, 5], 'translation': [9, 18, 27], 'rotation': [180, 0, 180], 'degree': True, 'batchsize': 3, 'ndim': 2, 'dtype':torch.float} ] @@ -282,16 +274,15 @@ def test_matrix_parametrization(self): torch.tensor([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]]), - torch.bmm(torch.bmm(torch.tensor([[[2., 3., 0], [4., 5., 0.], [0., 0., 1.]], - [[2., 3., 0.], [4., 5., 0.], [0., 0., 1.]], - [[2., 3., 0.], [4., 5., 0.], [0., 0., 1.]]]), + torch.bmm(torch.bmm(torch.tensor([[[2., 0., 0], [0., 5., 0.], [0., 0., 1.]], + [[2., 0., 0.], [0., 5., 0.], [0., 0., 1.]], + [[2., 0., 0.], [0., 5., 0.], [0., 0., 1.]]]), torch.tensor([[[-1., 0., 0.], [0., -1., 0.], [0., 0., 1.]], - [[-1., 0., 0.], [0., -1., 0.], [0., 0., 1.]], - [[-1., 0., 0.], [0., -1., 0.], [0., 0., 1.]]])), - torch.tensor([[[1., 2., 3.], [4., 5., 6.], [0., 0., 1.]], - [[10., 11., 12.], [13., 14., 15.], [0., 0., 1.]], - [[19., 20., 21.], [22., 23., 24.], [0., 0., 1.]]])) - + [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], + [[-1., 0., 0.], [0., -1., 0.], [0., 0., -1.]]])), + torch.tensor([[[1., 0., 9.], [0., 1., 9.], [0., 0., 1.]], + [[1., 0., 18.], [0., 1., 18.], [0., 0., 1.]], + [[1., 0., 27.], [0., 1., 27.], [0., 0., 1.]]])) ] for inp, exp 
in zip(inputs, expectations): From 910828dd49c0d980376bcf5f53b24c43420115c5 Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Sat, 22 Feb 2020 22:35:44 +0100 Subject: [PATCH 08/20] add translation unit --- rising/transforms/affine.py | 101 +++++++++++++++---------- rising/transforms/functional/affine.py | 8 +- tests/transforms/test_affine.py | 10 +++ 3 files changed, 74 insertions(+), 45 deletions(-) diff --git a/rising/transforms/affine.py b/rising/transforms/affine.py index b9e32ce0..9dad72aa 100644 --- a/rising/transforms/affine.py +++ b/rising/transforms/affine.py @@ -1,10 +1,12 @@ +import torch +from typing import Sequence, Union, Iterable + from rising.transforms.abstract import BaseTransform from rising.transforms.functional.affine import affine_image_transform, \ AffineParamType, assemble_matrix_if_necessary from rising.utils.affine import matrix_to_homogeneous, matrix_to_cartesian from rising.utils.checktype import check_scalar -import torch -from typing import Sequence, Union, Iterable + __all__ = [ 'Affine', @@ -48,7 +50,8 @@ def __init__(self, scale: AffineParamType = None, * a parameter per sampler per dimension None will be treated as a scaling factor of 1 rotation : torch.Tensor, int, float, optional - the rotation factor(s). Supported are: + the rotation factor(s). The rotation is performed in + consecutive order axis0 -> axis1 (-> axis 2). Supported are: * a single parameter (as float or int), which will be replicated for all dimensions and batch samples * a parameter per sample, which will be @@ -58,7 +61,8 @@ def __init__(self, scale: AffineParamType = None, * a parameter per sampler per dimension None will be treated as a rotation factor of 1 translation : torch.Tensor, int, float - the translation offset(s). Supported are: + the translation offset(s) relative to image (should be in the + range [0, 1]). Supported are: * a single parameter (as float or int), which will be replicated for all dimensions and batch samples * a parameter per sample, which will be @@ -136,7 +140,6 @@ def assemble_matrix(self, **data) -> torch.Tensor: ------- torch.Tensor the (batched) transformation matrix - """ batchsize = data[self.keys[0]].shape[0] ndim = len(data[self.keys[0]].shape) - 2 # channel and batch dim @@ -163,7 +166,6 @@ def forward(self, **data) -> dict: ------- dict dictionary containing the transformed data - """ matrix = self.assemble_matrix(**data) @@ -194,7 +196,6 @@ def __add__(self, other): ------- StackedAffine a stacked affine transformation - """ if not isinstance(other, Affine): other = Affine(matrix=other, keys=self.keys, grad=self.grad, @@ -226,7 +227,6 @@ def __radd__(self, other): ------- StackedAffine a stacked affine transformation - """ if not isinstance(other, Affine): other = Affine(matrix=other, keys=self.keys, grad=self.grad, @@ -267,15 +267,13 @@ def __init__(self, ---------- rotation : torch.Tensor, int, float, optional the rotation factor(s). 
Supported are: - * a full transformation matrix of shape - (BATCHSIZE x NDIM x NDIM) - * a single parameter (as float or int), which will be - replicated for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will - be replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for - all batch samples + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a parameter per sample, which will be + replicated for all dimensions + * a parameter per dimension, which will be replicated for all + batch samples + * a parameter per sampler per dimension None will be treated as a rotation factor of 1 keys: Sequence keys which should be augmented @@ -305,11 +303,6 @@ def __init__(self, making the sampling more resolution agnostic. **kwargs : additional keyword arguments passed to the affine transform - - Warnings - -------- - This transform is not applied around the image center - """ super().__init__(scale=None, rotation=rotation, @@ -336,6 +329,7 @@ def __init__(self, interpolation_mode: str = 'bilinear', padding_mode: str = 'zeros', align_corners: bool = False, + unit: str = 'relative', **kwargs): """ Class Performing an Translation-Only @@ -347,15 +341,13 @@ def __init__(self, ---------- translation : torch.Tensor, int, float the translation offset(s). Supported are: - * a full homogeneous transformation matrix of shape - (BATCHSIZE x NDIM+1 x NDIM+1) - * a single parameter (as float or int), which will be - replicated for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will - be replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for - all batch samples + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a parameter per sample, which will be + replicated for all dimensions + * a parameter per dimension, which will be replicated for all + batch samples + * a parameter per sampler per dimension None will be treated as a translation offset of 0 keys: Sequence keys which should be augmented @@ -370,7 +362,7 @@ def __init__(self, interpolation_mode : str interpolation mode to calculate output values 'bilinear' | 'nearest'. Default: 'bilinear' - padding_mode : + padding_mode : str padding mode for outside grid values 'zeros' | 'border' | 'reflection'. Default: 'zeros' align_corners : Geometrically, we consider the pixels of the input as @@ -379,9 +371,13 @@ def __init__(self, corner pixels. If set to False, they are instead considered as referring to the corner points of the input’s corner pixels, making the sampling more resolution agnostic. + unit: str + defines the unit of the translation parameter. + 'pixel': define number of pixels to translate | 'relative': + translation should be in the range [0, 1] and is scaled + with the image size **kwargs : additional keyword arguments passed to the affine transform - """ super().__init__(scale=None, rotation=None, @@ -396,6 +392,29 @@ def __init__(self, padding_mode=padding_mode, align_corners=align_corners, **kwargs) + self.unit = unit + + def assemble_matrix(self, **data) -> torch.Tensor: + """ + Assembles the matrix (and takes care of batching and having it on the + right device and in the correct dtype and dimensionality). 
+ + Parameters + ---------- + **data : + the data to be transformed. Will be used to determine batchsize, + dimensionality, dtype and device + + Returns + ------- + torch.Tensor + the (batched) transformation matrix [N, NDIM, NDIM] + """ + matrix = super().assemble_matrix(**data) + if self.unit.lower() == 'pixel': + img_size = torch.tensor(data[self.keys[0]].shape[2:]).to(matrix) + matrix[..., -1] = matrix[..., -1] / img_size + return matrix class Scale(Affine): @@ -419,15 +438,13 @@ def __init__(self, ---------- scale : torch.Tensor, int, float, optional the scale factor(s). Supported are: - * a full transformation matrix of shape - (BATCHSIZE x NDIM x NDIM) - * a single parameter (as float or int), which will be - replicated for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will - be replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for - all batch samples + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a parameter per sample, which will be + replicated for all dimensions + * a parameter per dimension, which will be replicated for all + batch samples + * a parameter per sampler per dimension None will be treated as a scaling factor of 1 keys: Sequence keys which should be augmented diff --git a/rising/transforms/functional/affine.py b/rising/transforms/functional/affine.py index 002d1d52..bd173708 100644 --- a/rising/transforms/functional/affine.py +++ b/rising/transforms/functional/affine.py @@ -482,9 +482,11 @@ def affine_point_transform(point_batch: torch.Tensor, Parameters ---------- point_batch : torch.Tensor - a point batch of shape BATCHSIZE x NUM_POINTS x NDIM + a point batch of shape [N, NP, NDIM] NP is the number of points, + N is the batch size, NDIM is the number of spatial dimensions matrix_batch : torch.Tensor - a batch of affine matrices with shape N x NDIM-1 x NDIM + a batch of affine matrices with shape [N, NDIM, NDIM + 1], + N is the batch size and NDIM is the number of spatial dimensions Returns ------- @@ -511,7 +513,7 @@ def affine_image_transform(image_batch: torch.Tensor, Parameters ---------- image_batch : torch.Tensor - the batch to transform. Should have shape of [N, C, dims] + the batch to transform. 
Should have shape of [N, C, NDIM] matrix_batch : torch.Tensor a batch of affine matrices with shape [N, NDIM, NDIM+1] output_size : Iterable diff --git a/tests/transforms/test_affine.py b/tests/transforms/test_affine.py index a825c620..bfc3a83c 100644 --- a/tests/transforms/test_affine.py +++ b/tests/transforms/test_affine.py @@ -98,6 +98,16 @@ def test_affine_subtypes(self): self.assertIsInstance(result, torch.Tensor) self.assertTupleEqual(expected_size, result.shape[-2:]) + def test_translation_assemble_matrix_with_pixel(self): + trafo = Translate([1, 10, 100], unit='pixel') + sample = {'data': torch.rand(3, 3, 100, 100)} + expected = torch.tensor([[1., 0., 0.01], [0., 1., 0.01], + [1., 0., 0.1], [0., 1., 0.1], + [1., 0., 1.], [0., 1., 1.]]) + + trafo.assemble_matrix(**sample) + self.assertTrue(expected.allclose(expected)) + if __name__ == '__main__': unittest.main() From 2cbf4e97825906c95e58b1fb9ad405bded191be2 Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Sat, 22 Feb 2020 23:35:15 +0100 Subject: [PATCH 09/20] update notebook and test script --- .github/workflows/notebook_tests.yml | 1 + notebooks/transformations.ipynb | 19 ++++++++++++------- rising/transforms/functional/affine.py | 17 ++++++++++++----- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/.github/workflows/notebook_tests.yml b/.github/workflows/notebook_tests.yml index cac06e76..06c9a460 100644 --- a/.github/workflows/notebook_tests.yml +++ b/.github/workflows/notebook_tests.yml @@ -18,6 +18,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | + export TEST_ENV=1 python -m pip install --upgrade pip pip install -U pip wheel; pip install -r requirements/install.txt; diff --git a/notebooks/transformations.ipynb b/notebooks/transformations.ipynb index 09cec93f..acefe207 100644 --- a/notebooks/transformations.ipynb +++ b/notebooks/transformations.ipynb @@ -30,7 +30,12 @@ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", - "TEST=0 # TODO: replace this with environment flag" + "import os\n", + "if 'TEST_ENV' in os.environ:\n", + " TEST_ENV = os.environ['TEST_ENV']\n", + "else:\n", + " TEST_ENV = 0\n", + "print(f\"Running test environment: {bool(TEST)}\")" ] }, { @@ -96,7 +101,7 @@ "%gui qt\n", "import napari\n", "def view_batch(batch):\n", - " if not TEST:\n", + " if not TEST_ENV:\n", " viewer = napari.view_image(batch[\"data\"].cpu().numpy(), name=\"data\")\n", " viewer.add_image(batch[\"mask\"].cpu().numpy(), name=\"mask\", opacity=0.2)" ] @@ -116,8 +121,8 @@ "from rising.transforms import *\n", "\n", "batch = {\n", - " \"data\": torch.from_numpy(img).float()[None],\n", - " \"mask\": torch.from_numpy(mask).long()[None],\n", + " \"data\": torch.from_numpy(img).float()[None, None],\n", + " \"mask\": torch.from_numpy(mask).long()[None, None],\n", "}\n", "\n", "def apply_transform(trafo, batch):\n", @@ -175,7 +180,7 @@ }, "outputs": [], "source": [ - "trafo = Rotate(45, degree=True, adjust_size=False)\n", + "trafo = Rotate([0, 0, 45], degree=True, adjust_size=False)\n", "transformed = apply_transform(trafo, batch)\n", "view_batch(transformed)" ] @@ -191,7 +196,7 @@ }, "outputs": [], "source": [ - "trafo = Translate(0.1, adjust_size=False)\n", + "trafo = Translate([0.1, 0, 0], adjust_size=False)\n", "transformed = apply_transform(trafo, batch)\n", "view_batch(transformed)" ] @@ -266,5 +271,5 @@ } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/rising/transforms/functional/affine.py b/rising/transforms/functional/affine.py 
index bd173708..37f54c90 100644 --- a/rising/transforms/functional/affine.py +++ b/rising/transforms/functional/affine.py @@ -92,7 +92,8 @@ def create_scale(scale: AffineParamType, Returns ------- torch.Tensor - the homogeneous transformation matrix + the homogeneous transformation matrix [N, NDIM + 1, NDIM + 1], N is + the batch size and NDIM is the number of spatial dimensions """ if scale is None: scale = 1 @@ -139,7 +140,8 @@ def create_translation(offset: AffineParamType, Returns ------- torch.Tensor - the homogeneous transformation matrix + the homogeneous transformation matrix [N, NDIM + 1, NDIM + 1], N is + the batch size and NDIM is the number of spatial dimensions """ if offset is None: @@ -189,7 +191,8 @@ def create_rotation(rotation: AffineParamType, Returns ------- torch.Tensor - the homogeneous transformation matrix + the homogeneous transformation matrix [N, NDIM + 1, NDIM + 1], N is + the batch size and NDIM is the number of spatial dimensions """ if rotation is None: @@ -371,7 +374,8 @@ def parametrize_matrix(scale: AffineParamType, Returns ------- torch.Tensor - the transformation matrix [BATCHSIZE, NDIM, NDIM+1] + the transformation matrix [N, NDIM, NDIM+1], N is + the batch size and NDIM is the number of spatial dimensions """ scale = create_scale(scale, batchsize=batchsize, ndim=ndim, device=device, dtype=dtype) @@ -446,7 +450,8 @@ def assemble_matrix_if_necessary(batchsize: int, ndim: int, Returns ------- torch.Tensor - the assembled transformation matrix + the assembled transformation matrix [N, NDIM, NDIM+1], N is + the batch size and NDIM is the number of spatial dimensions """ if matrix is None: @@ -492,6 +497,8 @@ def affine_point_transform(point_batch: torch.Tensor, ------- torch.Tensor the batch of transformed points in cartesian coordinates) + [N, NP, NDIM] NP is the number of points, N is the batch size, + NDIM is the number of spatial dimensions """ point_batch = points_to_homogeneous(point_batch) matrix_batch = matrix_to_homogeneous(matrix_batch) From f367ed56da4fdfd6b52c4a88cc1fdf3f99249e42 Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Sat, 22 Feb 2020 23:39:37 +0100 Subject: [PATCH 10/20] fix notebook --- notebooks/transformations.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/notebooks/transformations.ipynb b/notebooks/transformations.ipynb index acefe207..4d518276 100644 --- a/notebooks/transformations.ipynb +++ b/notebooks/transformations.ipynb @@ -35,7 +35,7 @@ " TEST_ENV = os.environ['TEST_ENV']\n", "else:\n", " TEST_ENV = 0\n", - "print(f\"Running test environment: {bool(TEST)}\")" + "print(f\"Running test environment: {bool(TEST_ENV)}\")" ] }, { From 67404258171a02d9650ed91abba4d02731f7396b Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Sun, 23 Feb 2020 00:20:22 +0100 Subject: [PATCH 11/20] adjust transformations to match behavior on images --- notebooks/transformations.ipynb | 16 +++++--- rising/transforms/functional/affine.py | 44 +++++++++++++++++----- tests/transforms/functional/test_affine.py | 15 +++++--- tests/transforms/test_affine.py | 10 ++--- 4 files changed, 58 insertions(+), 27 deletions(-) diff --git a/notebooks/transformations.ipynb b/notebooks/transformations.ipynb index 4d518276..675f3678 100644 --- a/notebooks/transformations.ipynb +++ b/notebooks/transformations.ipynb @@ -30,6 +30,7 @@ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", + "%gui qt\n", "import os\n", "if 'TEST_ENV' in os.environ:\n", " TEST_ENV = os.environ['TEST_ENV']\n", @@ -98,12 +99,15 @@ }, "outputs": [], 
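The environment flag read in the notebook cell above can also be expressed more compactly; a sketch equivalent in spirit to that cell (and to the .lower() == "true" comparison added later in this series):

import os

# treat anything except the literal string "true" (case-insensitive) as False
TEST_ENV = os.environ.get('TEST_ENV', '').lower() == 'true'
print(f"Running test environment: {TEST_ENV}")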
"source": [ - "%gui qt\n", - "import napari\n", - "def view_batch(batch):\n", - " if not TEST_ENV:\n", - " viewer = napari.view_image(batch[\"data\"].cpu().numpy(), name=\"data\")\n", - " viewer.add_image(batch[\"mask\"].cpu().numpy(), name=\"mask\", opacity=0.2)" + "if TEST_ENV:\n", + " def view_batch(batch):\n", + " pass\n", + "else:\n", + " %gui qt\n", + " import napari\n", + " def view_batch(batch):\n", + " viewer = napari.view_image(batch[\"data\"].cpu().numpy(), name=\"data\")\n", + " viewer.add_image(batch[\"mask\"].cpu().numpy(), name=\"mask\", opacity=0.2)" ] }, { diff --git a/rising/transforms/functional/affine.py b/rising/transforms/functional/affine.py index 37f54c90..fead1886 100644 --- a/rising/transforms/functional/affine.py +++ b/rising/transforms/functional/affine.py @@ -63,7 +63,8 @@ def expand_scalar_param(param: AffineParamType, batchsize: int, ndim: int) -> Te def create_scale(scale: AffineParamType, batchsize: int, ndim: int, device: Union[torch.device, str] = None, - dtype: Union[torch.dtype, str] = None) -> torch.Tensor: + dtype: Union[torch.dtype, str] = None, + image_transform: bool = True) -> torch.Tensor: """ Formats the given scale parameters to a homogeneous transformation matrix @@ -88,6 +89,10 @@ def create_scale(scale: AffineParamType, device dtype : torch.dtype, str, optional the dtype of the resulting trensor. Defaults to the default dtype + image_transform: bool + inverts the scale matrix to match expected behavior when applied to + an image, e.g. scale>1 increases the size of an image but decrease + the size of an grid Returns ------- @@ -100,6 +105,8 @@ def create_scale(scale: AffineParamType, scale = expand_scalar_param(scale, batchsize, ndim).to( device=device, dtype=dtype) + if image_transform: + scale = 1 / scale scale_matrix = torch.stack( [eye * s for eye, s in zip(get_batched_eye( batchsize=batchsize, ndim=ndim, device=device, dtype=dtype), scale)]) @@ -109,8 +116,8 @@ def create_scale(scale: AffineParamType, def create_translation(offset: AffineParamType, batchsize: int, ndim: int, device: Union[torch.device, str] = None, - dtype: Union[torch.dtype, str] = None - ) -> torch.Tensor: + dtype: Union[torch.dtype, str] = None, + image_transform: bool = True) -> torch.Tensor: """ Formats the given translation parameters to a homogeneous transformation matrix @@ -136,13 +143,16 @@ def create_translation(offset: AffineParamType, device dtype : torch.dtype, str, optional the dtype of the resulting trensor. Defaults to the default dtype + image_transform: bool + inverts the translation matrix to match expected behavior when applied + to an image, e.g. 
translation > 0 should move the image in the + positive direction of an axis but the grid in the negative direction Returns ------- torch.Tensor the homogeneous transformation matrix [N, NDIM + 1, NDIM + 1], N is the batch size and NDIM is the number of spatial dimensions - """ if offset is None: offset = 0 @@ -151,6 +161,8 @@ def create_translation(offset: AffineParamType, eye_batch = get_batched_eye(batchsize=batchsize, ndim=ndim, device=device, dtype=dtype) translation_matrix = torch.stack([torch.cat([eye, o.view(-1, 1)], dim=1) for eye, o in zip(eye_batch, offset)]) + if image_transform: + translation_matrix[..., -1] = -translation_matrix[..., -1] return matrix_to_homogeneous(translation_matrix) @@ -321,7 +333,9 @@ def parametrize_matrix(scale: AffineParamType, batchsize: int, ndim: int, degree: bool = False, device: Union[torch.device, str] = None, - dtype: Union[torch.dtype, str] = None) -> torch.Tensor: + dtype: Union[torch.dtype, str] = None, + image_transform: bool = True, + ) -> torch.Tensor: """ Formats the given scale parameters to a homogeneous transformation matrix @@ -370,6 +384,10 @@ def parametrize_matrix(scale: AffineParamType, device dtype : torch.dtype, str, optional the dtype of the resulting trensor. Defaults to the default dtype + image_transform: bool + adjusts transformation matrices such that they match the expected + behavior on images (see :func:`create_scale` and + :func:`create_translation` for more info) Returns ------- @@ -378,12 +396,13 @@ def parametrize_matrix(scale: AffineParamType, the batch size and NDIM is the number of spatial dimensions """ scale = create_scale(scale, batchsize=batchsize, ndim=ndim, - device=device, dtype=dtype) + device=device, dtype=dtype, + image_transform=image_transform) rotation = create_rotation(rotation, batchsize=batchsize, ndim=ndim, degree=degree, device=device, dtype=dtype) - translation = create_translation(translation, batchsize=batchsize, - ndim=ndim, device=device, dtype=dtype) + ndim=ndim, device=device, dtype=dtype, + image_transform=image_transform) return torch.bmm(torch.bmm(scale, rotation), translation)[:, :-1] @@ -394,7 +413,8 @@ def assemble_matrix_if_necessary(batchsize: int, ndim: int, matrix: torch.Tensor, degree: bool, device: Union[torch.device, str], - dtype: Union[torch.dtype, str] + dtype: Union[torch.dtype, str], + image_transform: bool = True, ) -> torch.Tensor: """ Assembles a matrix, if the matrix is not already given @@ -446,6 +466,10 @@ def assemble_matrix_if_necessary(batchsize: int, ndim: int, the device, the matrix should be put on dtype : str, torch.dtype the datatype, the matrix should have + image_transform: bool + adjusts transformation matrices such that they match the expected + behavior on images (see :func:`create_scale` and + :func:`create_translation` for more info) Returns ------- @@ -458,7 +482,7 @@ def assemble_matrix_if_necessary(batchsize: int, ndim: int, matrix = parametrize_matrix( scale=scale, rotation=rotation, translation=translation, batchsize=batchsize, ndim=ndim, degree=degree, - device=device, dtype=dtype) + device=device, dtype=dtype, image_transform=image_transform) else: if not torch.is_tensor(matrix): matrix = torch.tensor(matrix) diff --git a/tests/transforms/functional/test_affine.py b/tests/transforms/functional/test_affine.py index 2b947cf2..50a0a7a4 100644 --- a/tests/transforms/functional/test_affine.py +++ b/tests/transforms/functional/test_affine.py @@ -43,7 +43,8 @@ def test_check_image_size(self): affine = matrix_to_homogeneous( 
parametrize_matrix(scale=scale, rotation=rot, translation=tran, degree=True, - batchsize=1, ndim=ndim, dtype=torch.float)) + batchsize=1, ndim=ndim, dtype=torch.float, + image_transform=False)) edge_pts = torch.tensor(edge_pts, dtype=torch.float) img = img.to(torch.float) @@ -78,7 +79,8 @@ def test_affine_point_transform(self): rotation=[0, 0, 90], degree=True, batchsize=1, ndim=3, dtype=torch.float, - device='cpu') + device='cpu', + image_transform=False) ] expected = [ [[0, 5], [1, 0]], @@ -160,7 +162,7 @@ def test_create_scale(self): for inp, exp in zip(inputs, expectations): with self.subTest(input=inp, expected=exp): - res = create_scale(**inp).to(exp.dtype) + res = create_scale(**inp, image_transform=False).to(exp.dtype) self.assertTrue(torch.allclose(res, exp, atol=1e-6)) with self.assertRaises(ValueError): @@ -205,7 +207,7 @@ def test_create_translation(self): for inp, exp in zip(inputs, expectations): with self.subTest(input=inp, expected=exp): - res = create_translation(**inp).to(exp.dtype) + res = create_translation(**inp, image_transform=False).to(exp.dtype) self.assertTrue(torch.allclose(res, exp, atol=1e-6)) with self.assertRaises(ValueError): @@ -287,7 +289,7 @@ def test_matrix_parametrization(self): for inp, exp in zip(inputs, expectations): with self.subTest(input=inp, expected=exp): - res = parametrize_matrix(**inp).to(exp.dtype) + res = parametrize_matrix(**inp, image_transform=False).to(exp.dtype) self.assertTrue(torch.allclose(res, matrix_to_cartesian(exp), atol=1e-6)) def test_necessary_assembly(self): @@ -315,7 +317,8 @@ def test_necessary_assembly(self): for inp, exp in zip(inputs, expectations): with self.subTest(input=inp, expected=exp): res = assemble_matrix_if_necessary(**inp, degree=False, - device='cpu', scale=None, rotation=None).to(exp.dtype) + device='cpu', scale=None, rotation=None, + image_transform=False).to(exp.dtype) self.assertTrue(torch.allclose(res, exp, atol=1e-6)) with self.assertRaises(ValueError): diff --git a/tests/transforms/test_affine.py b/tests/transforms/test_affine.py index bfc3a83c..ff48165d 100644 --- a/tests/transforms/test_affine.py +++ b/tests/transforms/test_affine.py @@ -81,13 +81,13 @@ def test_affine_subtypes(self): sample = {'data': torch.rand(1, 3, 25, 30)} trafos = [ - Scale([2, 3], adjust_size=True), + Scale([5, 3], adjust_size=True), Resize([50, 90]), Rotate([90], adjust_size=True, degree=True), ] expected_sizes = [ - (50, 90), + (5, 10), (50, 90), (30, 25), ] @@ -101,9 +101,9 @@ def test_affine_subtypes(self): def test_translation_assemble_matrix_with_pixel(self): trafo = Translate([1, 10, 100], unit='pixel') sample = {'data': torch.rand(3, 3, 100, 100)} - expected = torch.tensor([[1., 0., 0.01], [0., 1., 0.01], - [1., 0., 0.1], [0., 1., 0.1], - [1., 0., 1.], [0., 1., 1.]]) + expected = torch.tensor([[1., 0., -0.01], [0., 1., -0.01], + [1., 0., -0.1], [0., 1., -0.1], + [1., 0., -1.], [0., 1., -1.]]) trafo.assemble_matrix(**sample) self.assertTrue(expected.allclose(expected)) From 53c65cfbd03d29aef965b52fbd70603f94139469 Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Sun, 23 Feb 2020 00:33:02 +0100 Subject: [PATCH 12/20] update notebook tests --- .github/workflows/notebook_tests.yml | 4 ++-- notebooks/transformations.ipynb | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/notebook_tests.yml b/.github/workflows/notebook_tests.yml index 06c9a460..3c1a6634 100644 --- a/.github/workflows/notebook_tests.yml +++ b/.github/workflows/notebook_tests.yml @@ -9,7 +9,8 @@ jobs: max-parallel: 4 
matrix: python-version: [3.7] - + env: + TEST_ENV: TRUE steps: - uses: actions/checkout@v1 - name: Set up Python ${{ matrix.python-version }} @@ -18,7 +19,6 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - export TEST_ENV=1 python -m pip install --upgrade pip pip install -U pip wheel; pip install -r requirements/install.txt; diff --git a/notebooks/transformations.ipynb b/notebooks/transformations.ipynb index 675f3678..2abe6d3d 100644 --- a/notebooks/transformations.ipynb +++ b/notebooks/transformations.ipynb @@ -33,7 +33,7 @@ "%gui qt\n", "import os\n", "if 'TEST_ENV' in os.environ:\n", - " TEST_ENV = os.environ['TEST_ENV']\n", + " TEST_ENV = os.environ['TEST_ENV'].lower() == \"true\"\n", "else:\n", " TEST_ENV = 0\n", "print(f\"Running test environment: {bool(TEST_ENV)}\")" From ffc118bdbebcaa31d3cda478f9fa385cc31200d1 Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Sun, 23 Feb 2020 11:56:23 +0100 Subject: [PATCH 13/20] introduce base affine --- rising/transforms/affine.py | 215 ++++++++++++++------- rising/transforms/functional/affine.py | 99 +--------- tests/transforms/functional/test_affine.py | 37 +--- tests/transforms/test_affine.py | 28 +++ 4 files changed, 180 insertions(+), 199 deletions(-) diff --git a/rising/transforms/affine.py b/rising/transforms/affine.py index 9dad72aa..a84c3bae 100644 --- a/rising/transforms/affine.py +++ b/rising/transforms/affine.py @@ -3,7 +3,7 @@ from rising.transforms.abstract import BaseTransform from rising.transforms.functional.affine import affine_image_transform, \ - AffineParamType, assemble_matrix_if_necessary + AffineParamType, parametrize_matrix from rising.utils.affine import matrix_to_homogeneous, matrix_to_cartesian from rising.utils.checktype import check_scalar @@ -19,13 +19,10 @@ class Affine(BaseTransform): - def __init__(self, scale: AffineParamType = None, - rotation: AffineParamType = None, - translation: AffineParamType = None, - matrix: torch.Tensor = None, + def __init__(self, + matrix: Union[torch.Tensor, Sequence[Sequence[float]]] = None, keys: Sequence = ('data',), grad: bool = False, - degree: bool = False, output_size: tuple = None, adjust_size: bool = False, interpolation_mode: str = 'bilinear', @@ -39,38 +36,6 @@ def __init__(self, scale: AffineParamType = None, Parameters ---------- - scale : torch.Tensor, int, float, optional - the scale factor(s). Supported are: - * a single parameter (as float or int), which will be replicated - for all dimensions and batch samples - * a parameter per sample, which will be - replicated for all dimensions - * a parameter per dimension, which will be replicated for all - batch samples - * a parameter per sampler per dimension - None will be treated as a scaling factor of 1 - rotation : torch.Tensor, int, float, optional - the rotation factor(s). The rotation is performed in - consecutive order axis0 -> axis1 (-> axis 2). Supported are: - * a single parameter (as float or int), which will be replicated - for all dimensions and batch samples - * a parameter per sample, which will be - replicated for all dimensions - * a parameter per dimension, which will be replicated for all - batch samples - * a parameter per sampler per dimension - None will be treated as a rotation factor of 1 - translation : torch.Tensor, int, float - the translation offset(s) relative to image (should be in the - range [0, 1]). 
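The parameter layouts listed next can be seen in action with the functional helpers; a short sketch mirroring the create_scale tests in this series (a sequence matching ndim is treated as per-dimension, one matching the batch size as per-sample):

from rising.transforms.functional.affine import create_scale

# two values for ndim=2 -> one scale factor per dimension, replicated over the batch
per_dim = create_scale([2, 3], batchsize=3, ndim=2)

# three values for batchsize=3 -> one scale factor per sample, replicated over dimensions
per_sample = create_scale([2, 3, 4], batchsize=3, ndim=2)

print(per_dim.shape, per_sample.shape)   # homogeneous matrices of shape [3, 3, 3]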
Supported are: - * a single parameter (as float or int), which will be replicated - for all dimensions and batch samples - * a parameter per sample, which will be - replicated for all dimensions - * a parameter per dimension, which will be replicated for all - batch samples - * a parameter per sampler per dimension - None will be treated as a translation offset of 0 matrix : torch.Tensor, optional if given, overwrites the parameters for :param:`scale`, :param:rotation` and :param:`translation`. @@ -80,10 +45,6 @@ def __init__(self, scale: AffineParamType = None, keys which should be augmented grad: bool enable gradient computation inside transformation - degree : bool - whether the given rotation(s) are in degrees. - Only valid for rotation parameters, which aren't passed as full - transformation matrix. output_size : Iterable if given, this will be the resulting image size. Defaults to ``None`` @@ -104,21 +65,12 @@ def __init__(self, scale: AffineParamType = None, making the sampling more resolution agnostic. **kwargs : additional keyword arguments passed to the affine transform - - Notes - ----- - If a :param:`matrix` is specified, it overwrites all arguments given - for :param:`scale`, :param:rotation` and :param:`translation` """ super().__init__(augment_fn=affine_image_transform, keys=keys, grad=grad, **kwargs) - self.scale = scale - self.rotation = rotation - self.translation = translation self.matrix = matrix - self.degree = degree self.output_size = output_size self.adjust_size = adjust_size self.interpolation_mode = interpolation_mode @@ -141,17 +93,30 @@ def assemble_matrix(self, **data) -> torch.Tensor: torch.Tensor the (batched) transformation matrix """ + if self.matrix is None: + raise ValueError("Matrix needs to be initialized or overwritten.") + if not torch.is_tensor(self.matrix): + self.matrix = torch.tensor(self.matrix) + self.matrix = self.matrix.to(data[self.keys[0]]) + batchsize = data[self.keys[0]].shape[0] ndim = len(data[self.keys[0]].shape) - 2 # channel and batch dim - device = data[self.keys[0]].device - dtype = data[self.keys[0]].dtype - matrix = assemble_matrix_if_necessary( - batchsize, ndim, scale=self.scale, rotation=self.rotation, - translation=self.translation, matrix=self.matrix, - degree=self.degree, device=device, dtype=dtype) - - return matrix + # batch dimension missing -> Replicate for each sample in batch + if len(self.matrix.shape) == 2: + self.matrix = self.matrix[None].expand(batchsize, -1, -1).clone() + if self.matrix.shape == (batchsize, ndim, ndim + 1): + return self.matrix + elif self.matrix.shape == (batchsize, ndim, ndim): + return matrix_to_homogeneous(self.matrix)[:, :-1] + elif self.matrix.shape == (batchsize, ndim + 1, ndim + 1): + return matrix_to_cartesian(self.matrix) + + raise ValueError( + "Invalid Shape for affine transformation matrix. 
" + "Got %s but expected %s" % ( + str(tuple(self.matrix.shape)), + str((batchsize, ndim, ndim + 1)))) def forward(self, **data) -> dict: """ @@ -245,7 +210,129 @@ def __radd__(self, other): **other.kwargs) -class Rotate(Affine): +class BaseAffine(Affine): + def __init__(self, + scale: AffineParamType = None, + rotation: AffineParamType = None, + translation: AffineParamType = None, + degree: bool = False, + image_transform: bool = True, + keys: Sequence = ('data',), + grad: bool = False, + output_size: tuple = None, + adjust_size: bool = False, + interpolation_mode: str = 'bilinear', + padding_mode: str = 'zeros', + align_corners: bool = False, + **kwargs, + ): + """ + Class performing a basic Affine Transformation on a given sample dict. + The transformation will be applied to all the dict-entries specified + in :attr:`keys`. + + Parameters + ---------- + scale : torch.Tensor, int, float, optional + the scale factor(s). Supported are: + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a parameter per sample, which will be + replicated for all dimensions + * a parameter per dimension, which will be replicated for all + batch samples + * a parameter per sampler per dimension + None will be treated as a scaling factor of 1 + rotation : torch.Tensor, int, float, optional + the rotation factor(s). The rotation is performed in + consecutive order axis0 -> axis1 (-> axis 2). Supported are: + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a parameter per sample, which will be + replicated for all dimensions + * a parameter per dimension, which will be replicated for all + batch samples + * a parameter per sampler per dimension + None will be treated as a rotation factor of 1 + translation : torch.Tensor, int, float + the translation offset(s) relative to image (should be in the + range [0, 1]). Supported are: + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a parameter per sample, which will be + replicated for all dimensions + * a parameter per dimension, which will be replicated for all + batch samples + * a parameter per sampler per dimension + None will be treated as a translation offset of 0 + keys: Sequence + keys which should be augmented + grad: bool + enable gradient computation inside transformation + degree : bool + whether the given rotation(s) are in degrees. + Only valid for rotation parameters, which aren't passed as full + transformation matrix. + output_size : Iterable + if given, this will be the resulting image size. + Defaults to ``None`` + adjust_size : bool + if True, the resulting image size will be calculated dynamically + to ensure that the whole image fits. + interpolation_mode : str + interpolation mode to calculate output values + 'bilinear' | 'nearest'. Default: 'bilinear' + padding_mode : + padding mode for outside grid values + 'zeros' | 'border' | 'reflection'. Default: 'zeros' + align_corners : Geometrically, we consider the pixels of the input as + squares rather than points. If set to True, the extrema (-1 and 1) + are considered as referring to the center points of the input’s + corner pixels. If set to False, they are instead considered as + referring to the corner points of the input’s corner pixels, + making the sampling more resolution agnostic. 
+ **kwargs : + additional keyword arguments passed to the affine transform + """ + super().__init__(keys=keys, grad=grad, output_size=output_size, + adjust_size=adjust_size, interpolation_mode=interpolation_mode, + padding_mode=padding_mode, align_corners=align_corners, + **kwargs) + self.scale = scale + self.rotation = rotation + self.translation = translation + self.degree = degree + self.image_transform = image_transform + + def assemble_matrix(self, **data) -> torch.Tensor: + """ + Assembles the matrix (and takes care of batching and having it on the + right device and in the correct dtype and dimensionality). + + Parameters + ---------- + **data : + the data to be transformed. Will be used to determine batchsize, + dimensionality, dtype and device + + Returns + ------- + torch.Tensor + the (batched) transformation matrix + """ + batchsize = data[self.keys[0]].shape[0] + ndim = len(data[self.keys[0]].shape) - 2 # channel and batch dim + device = data[self.keys[0]].device + dtype = data[self.keys[0]].dtype + + self.matrix = parametrize_matrix( + scale=self.scale, rotation=self.rotation, translation=self.translation, + batchsize=batchsize, ndim=ndim, degree=self.degree, + device=device, dtype=dtype, image_transform=self.image_transform) + return self.matrix + + +class Rotate(BaseAffine): def __init__(self, rotation: AffineParamType, keys: Sequence = ('data',), @@ -319,7 +406,7 @@ def __init__(self, **kwargs) -class Translate(Affine): +class Translate(BaseAffine): def __init__(self, translation: AffineParamType, keys: Sequence = ('data',), @@ -417,7 +504,7 @@ def assemble_matrix(self, **data) -> torch.Tensor: return matrix -class Scale(Affine): +class Scale(BaseAffine): def __init__(self, scale: AffineParamType, keys: Sequence = ('data',), @@ -556,9 +643,7 @@ def __init__( [trafo if isinstance(trafo, Affine) else Affine(matrix=trafo) for trafo in transforms]) - super().__init__(matrix=None, - scale=None, rotation=None, translation=None, - keys=keys, grad=grad, degree=False, + super().__init__(keys=keys, grad=grad, output_size=output_size, adjust_size=adjust_size, interpolation_mode=interpolation_mode, padding_mode=padding_mode, diff --git a/rising/transforms/functional/affine.py b/rising/transforms/functional/affine.py index fead1886..943f2107 100644 --- a/rising/transforms/functional/affine.py +++ b/rising/transforms/functional/affine.py @@ -4,7 +4,7 @@ from typing import Union, Sequence from rising.utils.affine import points_to_cartesian, matrix_to_homogeneous, \ - points_to_homogeneous, unit_box, get_batched_eye, deg_to_rad, matrix_to_cartesian + points_to_homogeneous, unit_box, get_batched_eye, deg_to_rad from rising.utils.checktype import check_scalar @@ -406,103 +406,6 @@ def parametrize_matrix(scale: AffineParamType, return torch.bmm(torch.bmm(scale, rotation), translation)[:, :-1] -def assemble_matrix_if_necessary(batchsize: int, ndim: int, - scale: AffineParamType, - rotation: AffineParamType, - translation: AffineParamType, - matrix: torch.Tensor, - degree: bool, - device: Union[torch.device, str], - dtype: Union[torch.dtype, str], - image_transform: bool = True, - ) -> torch.Tensor: - """ - Assembles a matrix, if the matrix is not already given - - Parameters - ---------- - batchsize : int - number of samples per batch - ndim : int - the image dimensionality - scale : torch.Tensor, int, float - the scale factor(s). 
Supported are: - * a single parameter (as float or int), which will be replicated - for all dimensions and batch samples - * a parameter per sample, which will be - replicated for all dimensions - * a parameter per dimension, which will be replicated for all - batch samples - * a parameter per sampler per dimension - None will be treated as a scaling factor of 1 - rotation : torch.Tensor, int, float - the rotation factor(s). Supported are: - * a single parameter (as float or int), which will be replicated - for all dimensions and batch samples - * a parameter per sample, which will be - replicated for all dimensions - * a parameter per dimension, which will be replicated for all - batch samples - * a parameter per sampler per dimension - None will be treated as a rotation factor of 1 - translation : torch.Tensor, int, float - the translation offset(s). Supported are: - * a single parameter (as float or int), which will be replicated - for all dimensions and batch samples - * a parameter per sample, which will be - replicated for all dimensions - * a parameter per dimension, which will be replicated for all - batch samples - * a parameter per sampler per dimension - None will be treated as a translation offset of 0 - matrix : torch.Tensor - the transformation matrix. If other than None: overwrites separate - parameters for :param:`scale`, :param:`rotation` and - :param:`translation` - degree : bool - whether the given rotation is in degrees. Only valid for explicit - rotation parameters - device : str, torch.device - the device, the matrix should be put on - dtype : str, torch.dtype - the datatype, the matrix should have - image_transform: bool - adjusts transformation matrices such that they match the expected - behavior on images (see :func:`create_scale` and - :func:`create_translation` for more info) - - Returns - ------- - torch.Tensor - the assembled transformation matrix [N, NDIM, NDIM+1], N is - the batch size and NDIM is the number of spatial dimensions - - """ - if matrix is None: - matrix = parametrize_matrix( - scale=scale, rotation=rotation, translation=translation, - batchsize=batchsize, ndim=ndim, degree=degree, - device=device, dtype=dtype, image_transform=image_transform) - else: - if not torch.is_tensor(matrix): - matrix = torch.tensor(matrix) - matrix = matrix.to(dtype=dtype, device=device) - - # batch dimension missing -> Replicate for each sample in batch - if len(matrix.shape) == 2: - matrix = matrix[None].expand(batchsize, -1, -1).clone() - if matrix.shape == (batchsize, ndim, ndim + 1): - return matrix - elif matrix.shape == (batchsize, ndim + 1, ndim + 1): - return matrix_to_cartesian(matrix) - - raise ValueError( - "Invalid Shape for affine transformation matrix. 
" - "Got %s but expected %s" % ( - str(tuple(matrix.shape)), - str((batchsize, ndim, ndim + 1)))) - - def affine_point_transform(point_batch: torch.Tensor, matrix_batch: torch.Tensor) -> torch.Tensor: """ diff --git a/tests/transforms/functional/test_affine.py b/tests/transforms/functional/test_affine.py index 50a0a7a4..544b9b7d 100644 --- a/tests/transforms/functional/test_affine.py +++ b/tests/transforms/functional/test_affine.py @@ -2,7 +2,7 @@ import torch from rising.transforms.functional.affine import _check_new_img_size, \ affine_point_transform, affine_image_transform, parametrize_matrix, \ - create_rotation, create_translation, create_scale, assemble_matrix_if_necessary + create_rotation, create_translation, create_scale from rising.utils.affine import matrix_to_homogeneous, matrix_to_cartesian @@ -292,41 +292,6 @@ def test_matrix_parametrization(self): res = parametrize_matrix(**inp, image_transform=False).to(exp.dtype) self.assertTrue(torch.allclose(res, matrix_to_cartesian(exp), atol=1e-6)) - def test_necessary_assembly(self): - inputs = [ - {'matrix': None, 'translation': [2, 3], 'ndim':2, 'batchsize': 3, - 'dtype': torch.float}, - {'matrix': [[1., 0., 4.], [0., 1., 5.], [0., 0., 1.]], 'translation': [2, 3], 'ndim': 2, 'batchsize': 3, - 'dtype': torch.float}, - {'matrix': [[1., 0., 4.], [0., 1., 5.]], 'translation': [2, 3], 'ndim': 2, 'batchsize': 3, - 'dtype': torch.float} - - ] - expectations = [ - torch.tensor([[[1., 0., 2.], [0., 1., 3.]], - [[1., 0., 2.], [0., 1., 3.]], - [[1., 0., 2.], [0., 1., 3.]]]), - torch.tensor([[[1., 0., 4.], [0., 1., 5.]], - [[1., 0., 4.], [0., 1., 5.]], - [[1., 0., 4.], [0., 1., 5.]]]), - torch.tensor([[[1., 0., 4.], [0., 1., 5.]], - [[1., 0., 4.], [0., 1., 5.]], - [[1., 0., 4.], [0., 1., 5.]]]) - ] - - for inp, exp in zip(inputs, expectations): - with self.subTest(input=inp, expected=exp): - res = assemble_matrix_if_necessary(**inp, degree=False, - device='cpu', scale=None, rotation=None, - image_transform=False).to(exp.dtype) - self.assertTrue(torch.allclose(res, exp, atol=1e-6)) - - with self.assertRaises(ValueError): - assemble_matrix_if_necessary(matrix=[1, 2, 3, 4, 5], scale=None, - rotation=None, translation=None, - degree=False, dtype=torch.float, - device='cpu', batchsize=1, ndim=2) - if __name__ == '__main__': unittest.main() diff --git a/tests/transforms/test_affine.py b/tests/transforms/test_affine.py index ff48165d..fbbd8f66 100644 --- a/tests/transforms/test_affine.py +++ b/tests/transforms/test_affine.py @@ -36,6 +36,34 @@ def test_affine(self): self.assertEqual(sample['label'], result['label']) + def test_affine_assemble_matrix(self): + matrices = [ + [[1., 0.], [0., 1.]], + [[1., 0., 1.], [0., 1., 1.]], + [[1., 0., 1.], [0., 1., 1.], [0., 0., 1.]], + None, + [0., 1., 1., 0.] 
+ ] + expected_matrices = [ + torch.tensor([[1., 0., 0.], [0., 1., 0.]])[None], + torch.tensor([[1., 0., 1.], [0., 1., 1.]])[None], + torch.tensor([[1., 0., 1.], [0., 1., 1.]])[None], + None, + None, + ] + value_error = [False, False, False, True, True] + batch = {"data": torch.zeros(1, 1, 10, 10)} + + for matrix, expected, ve in zip(matrices, expected_matrices, value_error): + with self.subTest(matrix=matrix, expected=expected): + trafo = Affine(matrix=matrix) + if ve: + with self.assertRaises(ValueError): + assembled = trafo.assemble_matrix(**batch) + else: + assembled = trafo.assemble_matrix(**batch) + self.assertTrue(expected.allclose(assembled)) + def test_affine_stacking(self): affines = [ Affine(scale=1), From 49383cc40da2846f6830c6d047d4602bc5e0ee72 Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Sun, 23 Feb 2020 12:20:48 +0100 Subject: [PATCH 14/20] update tests --- rising/transforms/affine.py | 200 +++++++++++---------- tests/transforms/functional/test_affine.py | 48 ----- 2 files changed, 101 insertions(+), 147 deletions(-) diff --git a/rising/transforms/affine.py b/rising/transforms/affine.py index a84c3bae..60548163 100644 --- a/rising/transforms/affine.py +++ b/rising/transforms/affine.py @@ -210,6 +210,103 @@ def __radd__(self, other): **other.kwargs) +class StackedAffine(Affine): + def __init__( + self, + *transforms: Union[Affine, Sequence[ + Union[Sequence[Affine], Affine]]], + keys: Sequence = ('data',), + grad: bool = False, + output_size: tuple = None, + adjust_size: bool = False, + interpolation_mode: str = 'bilinear', + padding_mode: str = 'zeros', + align_corners: bool = False, + **kwargs): + """ + Class Performing an Affine Transformation on a given sample dict. + The transformation will be applied to all the dict-entries specified + in :attr:`keys`. + + Parameters + ---------- + transforms : sequence of Affines + the transforms to stack. Each transform must have a function + called ``assemble_matrix``, which is called to dynamically + assemble stacked matrices. Afterwards these transformations are + stacked by matrix-multiplication to only perform a single + interpolation + keys: Sequence + keys which should be augmented + grad: bool + enable gradient computation inside transformation + output_size : Iterable + if given, this will be the resulting image size. + Defaults to ``None`` + adjust_size : bool + if True, the resulting image size will be calculated dynamically + to ensure that the whole image fits. + interpolation_mode : str + interpolation mode to calculate output values + 'bilinear' | 'nearest'. Default: 'bilinear' + padding_mode : + padding mode for outside grid values + 'zeros' | 'border' | 'reflection'. Default: 'zeros' + align_corners : Geometrically, we consider the pixels of the input as + squares rather than points. If set to True, the extrema (-1 and 1) + are considered as referring to the center points of the input’s + corner pixels. If set to False, they are instead considered as + referring to the corner points of the input’s corner pixels, + making the sampling more resolution agnostic. 
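A rough usage sketch of stacking (values are illustrative; it assumes the Rotate and Scale transforms from this module and a dict-style batch). Because the per-transform matrices are multiplied before sampling, the data is interpolated only once:

    import torch
    from rising.transforms import Rotate, Scale, StackedAffine

    batch = {"data": torch.rand(1, 1, 32, 32)}
    stacked = StackedAffine(Rotate(45, degree=True), Scale(2))
    # the + operator builds the same stacked transform:
    # stacked = Rotate(45, degree=True) + Scale(2)
    out = stacked(**batch)  # one interpolation step for both transforms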
+ **kwargs : + additional keyword arguments passed to the affine transform + + """ + + if isinstance(transforms, (tuple, list)): + if isinstance(transforms[0], (tuple, list)): + transforms = transforms[0] + + # ensure trafos are Affines and not raw matrices + transforms = tuple( + [trafo if isinstance(trafo, Affine) else Affine(matrix=trafo) + for trafo in transforms]) + + super().__init__(keys=keys, grad=grad, + output_size=output_size, adjust_size=adjust_size, + interpolation_mode=interpolation_mode, + padding_mode=padding_mode, + align_corners=align_corners, + **kwargs) + + self.transforms = transforms + + def assemble_matrix(self, **data) -> torch.Tensor: + """ + Handles the matrix assembly and stacking + + Parameters + ---------- + **data : + the data to be transformed. Will be used to determine batchsize, + dimensionality, dtype and device + + Returns + ------- + torch.Tensor + the (batched) transformation matrix + + """ + whole_trafo = None + for trafo in self.transforms: + matrix = matrix_to_homogeneous(trafo.assemble_matrix(**data)) + if whole_trafo is None: + whole_trafo = matrix + else: + whole_trafo = torch.bmm(whole_trafo, matrix) + return matrix_to_cartesian(whole_trafo) + + class BaseAffine(Affine): def __init__(self, scale: AffineParamType = None, @@ -346,7 +443,8 @@ def __init__(self, **kwargs): """ Class Performing a Rotation-OnlyAffine Transformation on a given - sample dict. + sample dict. The rotation is applied in consecutive order: + rot axis 0 -> rot axis 1 -> rot axis 2 The transformation will be applied to all the dict-entries specified in :attr:`keys`. @@ -427,7 +525,8 @@ def __init__(self, Parameters ---------- translation : torch.Tensor, int, float - the translation offset(s). Supported are: + the translation offset(s). The translation unit can be specified. + Supported are: * a single parameter (as float or int), which will be replicated for all dimensions and batch samples * a parameter per sample, which will be @@ -581,103 +680,6 @@ def __init__(self, **kwargs) -class StackedAffine(Affine): - def __init__( - self, - *transforms: Union[Affine, Sequence[ - Union[Sequence[Affine], Affine]]], - keys: Sequence = ('data',), - grad: bool = False, - output_size: tuple = None, - adjust_size: bool = False, - interpolation_mode: str = 'bilinear', - padding_mode: str = 'zeros', - align_corners: bool = False, - **kwargs): - """ - Class Performing an Affine Transformation on a given sample dict. - The transformation will be applied to all the dict-entries specified - in :attr:`keys`. - - Parameters - ---------- - transforms : sequence of Affines - the transforms to stack. Each transform must have a function - called ``assemble_matrix``, which is called to dynamically - assemble stacked matrices. Afterwards these transformations are - stacked by matrix-multiplication to only perform a single - interpolation - keys: Sequence - keys which should be augmented - grad: bool - enable gradient computation inside transformation - output_size : Iterable - if given, this will be the resulting image size. - Defaults to ``None`` - adjust_size : bool - if True, the resulting image size will be calculated dynamically - to ensure that the whole image fits. - interpolation_mode : str - interpolation mode to calculate output values - 'bilinear' | 'nearest'. Default: 'bilinear' - padding_mode : - padding mode for outside grid values - 'zeros' | 'border' | 'reflection'. Default: 'zeros' - align_corners : Geometrically, we consider the pixels of the input as - squares rather than points. 
If set to True, the extrema (-1 and 1) - are considered as referring to the center points of the input’s - corner pixels. If set to False, they are instead considered as - referring to the corner points of the input’s corner pixels, - making the sampling more resolution agnostic. - **kwargs : - additional keyword arguments passed to the affine transform - - """ - - if isinstance(transforms, (tuple, list)): - if isinstance(transforms[0], (tuple, list)): - transforms = transforms[0] - - # ensure trafos are Affines and not raw matrices - transforms = tuple( - [trafo if isinstance(trafo, Affine) else Affine(matrix=trafo) - for trafo in transforms]) - - super().__init__(keys=keys, grad=grad, - output_size=output_size, adjust_size=adjust_size, - interpolation_mode=interpolation_mode, - padding_mode=padding_mode, - align_corners=align_corners, - **kwargs) - - self.transforms = transforms - - def assemble_matrix(self, **data) -> torch.Tensor: - """ - Handles the matrix assembly and stacking - - Parameters - ---------- - **data : - the data to be transformed. Will be used to determine batchsize, - dimensionality, dtype and device - - Returns - ------- - torch.Tensor - the (batched) transformation matrix - - """ - whole_trafo = None - for trafo in self.transforms: - matrix = matrix_to_homogeneous(trafo.assemble_matrix(**data)) - if whole_trafo is None: - whole_trafo = matrix - else: - whole_trafo = torch.bmm(whole_trafo, matrix) - return matrix_to_cartesian(whole_trafo) - - class Resize(Scale): def __init__(self, size: Union[int, Iterable], diff --git a/tests/transforms/functional/test_affine.py b/tests/transforms/functional/test_affine.py index 544b9b7d..57f24457 100644 --- a/tests/transforms/functional/test_affine.py +++ b/tests/transforms/functional/test_affine.py @@ -141,7 +141,6 @@ def test_create_scale(self): {'scale': 2, 'batchsize': 2, 'ndim': 2}, {'scale': [2, 3], 'batchsize': 3, 'ndim': 2}, {'scale': [2, 3, 4], 'batchsize': 3, 'ndim': 2}, - # {'scale': [[2, 3], [4, 5]], 'batchsize': 3, 'ndim': 2}, ] expectations = [ @@ -155,9 +154,6 @@ def test_create_scale(self): torch.tensor([[[2., 0., 0.], [0., 2., 0.], [0., 0., 1.]], [[3., 0., 0.], [0., 3., 0.], [0., 0., 1.]], [[4., 0., 0.], [0., 4., 0.], [0., 0., 1.]]]), - # torch.tensor([[[2, 3, 0], [4, 5, 0], [0, 0, 1]], - # [[2, 3, 0], [4, 5, 0], [0, 0, 1]], - # [[2, 3, 0], [4, 5, 0], [0, 0, 1]]]) ] for inp, exp in zip(inputs, expectations): @@ -174,15 +170,6 @@ def test_create_translation(self): {'offset': 2, 'batchsize': 2, 'ndim': 2}, {'offset': [2, 3], 'batchsize': 3, 'ndim': 2}, {'offset': [2, 3, 4], 'batchsize': 3, 'ndim': 2}, - # {'offset': [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], - # [[10, 11, 12], [13, 14, 15], [16, 17, 18]], - # [[19, 20, 21], [22, 23, 24], [25, 26, 27]]], - # 'batchsize': 3, 'ndim': 2}, - # {'offset': [[[1, 2, 3], [4, 5, 6]], - # [[10, 11, 12], [13, 14, 15]], - # [[19, 20, 21], [22, 23, 24]]], - # 'batchsize': 3, 'ndim': 2} - ] expectations = [ @@ -196,13 +183,6 @@ def test_create_translation(self): torch.tensor([[[1, 0, 2], [0, 1, 2], [0, 0, 1]], [[1, 0, 3], [0, 1, 3], [0, 0, 1]], [[1, 0, 4], [0, 1, 4], [0, 0, 1]]]), - # torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]], - # [[10, 11, 12], [13, 14, 15], [16, 17, 18]], - # [[19, 20, 21], [22, 23, 24], [25, 26, 27]]]), - # torch.tensor([[[1, 2, 3], [4, 5, 6], [0, 0, 1]], - # [[10, 11, 12], [13, 14, 15], [0, 0, 1]], - # [[19, 20, 21], [22, 23, 24], [0, 0, 1]]]) - ] for inp, exp in zip(inputs, expectations): @@ -217,19 +197,6 @@ def test_format_rotation(self): inputs = [ 
{'rotation': None, 'batchsize': 2, 'ndim': 3}, {'rotation': 0, 'degree': True, 'batchsize': 2, 'ndim': 2}, - - # TODO: update tests with multiple rotation - # {'rotation': [180, 0, 180], 'degree': True, 'batchsize': 2, 'ndim': 3}, - # {'rotation': [180, 0, 180], 'degree': True, 'batchsize': 3, 'ndim': 2}, - - # {'rotation': [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], - # [[10, 11, 12], [13, 14, 15], [16, 17, 18]]], - # 'batchsize': 2, 'ndim': 2}, - # {'rotation': [[[1, 2, 3], [4, 5, 6]], - # [[10, 11, 12], [13, 14, 15]]], - # 'batchsize': 2, 'ndim': 2}, - # {'rotation': [[1, 2], [3, 4]], 'batchsize': 3, 'ndim': 2, 'degree': False} - ] expectations = [ torch.tensor([[[1., 0., 0., 0.], [0., 1., 0., 0.], @@ -238,21 +205,6 @@ def test_format_rotation(self): [0., 0., 1., 0.], [0., 0., 0., 1.]]]), torch.tensor([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]]), - # torch.tensor([[[1., 0., 0., 0.], [0., 1., 0., 0.], - # [0., 0., 1., 0.], [0., 0., 0., 1.]], - # [[1., 0., 0., 0.], [0., 1., 0., 0.], - # [0., 0., 1., 0.], [0., 0., 0., 1.]]]), - # torch.tensor([[[-1, 0, 0], [0, -1, 0], [0, 0, 1]], - # [[-1, 0, 0], [0, -1, 0], [0, 0, 1]], - # [[-1, 0, 0], [0, -1, 0], [0, 0, 1]]]), - - # torch.tensor([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]], - # [[10., 11., 12.], [13., 14., 15.], [16., 17., 18.]]]), - # torch.tensor([[[1., 2., 3.], [4., 5., 6.], [0., 0., 1.]], - # [[10., 11., 12.], [13., 14., 15.], [0., 0., 1.]]]), - # torch.tensor([[[1., 2., 0.], [3., 4., 0.], [0., 0., 1.]], - # [[1., 2., 0.], [3., 4., 0.], [0., 0., 1.]], - # [[1., 2., 0.], [3., 4., 0.], [0., 0., 1.]]]) ] for inp, exp in zip(inputs, expectations): From 5fde88818c4911b36cc5cc443615ca817bf92f94 Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Sun, 23 Feb 2020 15:38:13 +0100 Subject: [PATCH 15/20] GridTransform scatch --- rising/transforms/affine.py | 2 - rising/transforms/grid.py | 89 +++++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 2 deletions(-) create mode 100644 rising/transforms/grid.py diff --git a/rising/transforms/affine.py b/rising/transforms/affine.py index 60548163..0b113d33 100644 --- a/rising/transforms/affine.py +++ b/rising/transforms/affine.py @@ -260,9 +260,7 @@ def __init__( making the sampling more resolution agnostic. 
**kwargs : additional keyword arguments passed to the affine transform - """ - if isinstance(transforms, (tuple, list)): if isinstance(transforms[0], (tuple, list)): transforms = transforms[0] diff --git a/rising/transforms/grid.py b/rising/transforms/grid.py new file mode 100644 index 00000000..167ec2c2 --- /dev/null +++ b/rising/transforms/grid.py @@ -0,0 +1,89 @@ +from typing import Sequence, Union, Dict, Tuple + +import torch + +from abc import abstractmethod +from torch import Tensor + +from rising.transforms import AbstractTransform +from rising.utils.affine import get_batched_eye, matrix_to_homogeneous + + +__all__ = ["GridTransform", "StackedGridTransform"] + + +class GridTransform(AbstractTransform): + def __init__(self, + keys: Sequence[str] = ('data',), + interpolation_mode: str = 'bilinear', + padding_mode: str = 'zeros', + align_corners: bool = False, + grad: bool = False,): + super().__init__(grad=grad) + self.keys = keys + self.interpolation_mode = interpolation_mode + self.padding_mode = padding_mode + self.align_corners = align_corners + + self.grid: Dict[Tuple, Tensor] = None + + def forward(self, **data) -> dict: + if self.grid is None: + self.grid = self.create_grid([data[key].shape for key in self.keys]) + + self.grid = self.augment_grid(self.grid) + + for key in self.keys: + _grid = self.grid[tuple(data[key].shape)] + + data[key] = torch.nn.functional.grid_sample( + data[key], _grid, mode=self.interpolation_mode, + padding_mode=self.padding_mode, align_corners=self.align_corners) + self.grid = None + return data + + @abstractmethod + def augment_grid(self, grid: Dict[Tuple, Tensor]) -> Dict[Tuple, Tensor]: + raise NotImplementedError + + def create_grid(self, input_size: Sequence[Sequence[int]], matrix: Tensor = None) -> \ + Dict[Tuple, Tensor]: + if matrix is None: + matrix = get_batched_eye(batchsize=input_size[0][0], ndim=len(input_size[0]) - 2) + matrix = matrix_to_homogeneous(matrix)[:, :-1] + + grid = {} + for size in input_size: + if tuple(size) not in grid: + grid[tuple(size)] = torch.nn.functional.affine_grid( + matrix, size=input_size, align_corners=self.align_corners) + return grid + + def __add__(self, other): + if not isinstance(other, GridTransform): + raise ValueError("Concatenation is only supported for grid transforms.") + return StackedGridTransform(self, other) + + def __radd__(self, other): + if not isinstance(other, GridTransform): + raise ValueError("Concatenation is only supported for grid transforms.") + return StackedGridTransform(other, self) + + +class StackedGridTransform(GridTransform): + def __init__(self, *transforms: Union[GridTransform, Sequence[GridTransform]]): + super().__init__(keys=None, interpolation_mode=None, padding_mode=None, + align_corners=None) + if isinstance(transforms, (tuple, list)): + if isinstance(transforms[0], (tuple, list)): + transforms = transforms[0] + self.transforms = transforms + + def create_grid(self, input_size: Sequence[Sequence[int]], matrix: Tensor = None) -> \ + Dict[Tuple, Tensor]: + return self.transforms[0].create_grid(input_size=input_size, matrix=matrix) + + def augment_grid(self, grid: Tensor) -> Tensor: + for transform in self.transforms: + grid = transform.augment_grid(grid) + return grid From e2a429c69c4e6ba77bbe5a76afe457ca031bb583 Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Sun, 23 Feb 2020 16:59:24 +0100 Subject: [PATCH 16/20] transform affine transforms into grid transforms --- rising/transforms/affine.py | 228 ++++++++++++--------- rising/transforms/functional/affine.py | 53 +++-- 
rising/transforms/grid.py | 13 +- tests/transforms/functional/test_affine.py | 22 +- tests/transforms/test_affine.py | 28 +-- 5 files changed, 190 insertions(+), 154 deletions(-) diff --git a/rising/transforms/affine.py b/rising/transforms/affine.py index 0b113d33..d8933d6f 100644 --- a/rising/transforms/affine.py +++ b/rising/transforms/affine.py @@ -1,8 +1,9 @@ import torch -from typing import Sequence, Union, Iterable +from torch import Tensor +from typing import Sequence, Union, Iterable, Dict, Tuple -from rising.transforms.abstract import BaseTransform -from rising.transforms.functional.affine import affine_image_transform, \ +from rising.transforms.grid import GridTransform +from rising.transforms.functional.affine import create_affine_grid, \ AffineParamType, parametrize_matrix from rising.utils.affine import matrix_to_homogeneous, matrix_to_cartesian from rising.utils.checktype import check_scalar @@ -18,9 +19,9 @@ ] -class Affine(BaseTransform): +class Affine(GridTransform): def __init__(self, - matrix: Union[torch.Tensor, Sequence[Sequence[float]]] = None, + matrix: Union[Tensor, Sequence[Sequence[float]]] = None, keys: Sequence = ('data',), grad: bool = False, output_size: tuple = None, @@ -36,7 +37,7 @@ def __init__(self, Parameters ---------- - matrix : torch.Tensor, optional + matrix : Tensor, optional if given, overwrites the parameters for :param:`scale`, :param:rotation` and :param:`translation`. Should be a matrix of shape [(BATCHSIZE,) NDIM, NDIM(+1)] @@ -64,43 +65,46 @@ def __init__(self, referring to the corner points of the input’s corner pixels, making the sampling more resolution agnostic. **kwargs : - additional keyword arguments passed to the affine transform + additional keyword arguments passed to grid sample """ - super().__init__(augment_fn=affine_image_transform, - keys=keys, - grad=grad, + super().__init__(keys=keys, grad=grad, interpolation_mode=interpolation_mode, + padding_mode=padding_mode, align_corners=align_corners, **kwargs) self.matrix = matrix self.output_size = output_size self.adjust_size = adjust_size - self.interpolation_mode = interpolation_mode - self.padding_mode = padding_mode - self.align_corners = align_corners - def assemble_matrix(self, **data) -> torch.Tensor: + def assemble_matrix(self, + batch_shape: Sequence[int], + device: Union[torch.device, str] = None, + dtype: Union[torch.dtype, str] = None, + ) -> Tensor: """ Assembles the matrix (and takes care of batching and having it on the right device and in the correct dtype and dimensionality). Parameters ---------- - **data : - the data to be transformed. 
Will be used to determine batchsize, - dimensionality, dtype and device + batch_shape : Sequence[int] + shape of batch + device: Union[torch.device, str] + device where grid will be cached + dtype: Union[torch.dtype, str] + data type of grid Returns ------- - torch.Tensor + Tensor the (batched) transformation matrix """ if self.matrix is None: raise ValueError("Matrix needs to be initialized or overwritten.") if not torch.is_tensor(self.matrix): - self.matrix = torch.tensor(self.matrix) - self.matrix = self.matrix.to(data[self.keys[0]]) + self.matrix = Tensor(self.matrix) + self.matrix = self.matrix.to(device=device, dtype=dtype) - batchsize = data[self.keys[0]].shape[0] - ndim = len(data[self.keys[0]].shape) - 2 # channel and batch dim + batchsize = batch_shape[0] + ndim = len(batch_shape) - 2 # channel and batch dim # batch dimension missing -> Replicate for each sample in batch if len(self.matrix.shape) == 2: @@ -118,34 +122,19 @@ def assemble_matrix(self, **data) -> torch.Tensor: str(tuple(self.matrix.shape)), str((batchsize, ndim, ndim + 1)))) - def forward(self, **data) -> dict: - """ - Assembles the matrix and applies it to the specified sample-entities. - - Parameters - ---------- - **data : - the data to transform - - Returns - ------- - dict - dictionary containing the transformed data - """ - matrix = self.assemble_matrix(**data) - - for key in self.keys: - data[key] = self.augment_fn( - data[key], matrix_batch=matrix, - output_size=self.output_size, - adjust_size=self.adjust_size, - interpolation_mode=self.interpolation_mode, - padding_mode=self.padding_mode, - align_corners=self.align_corners, - **self.kwargs - ) + def create_grid(self, input_size: Sequence[Sequence[int]], + matrix: Tensor = None) -> Dict[Tuple, Tensor]: + grid = {} + for size in input_size: + if tuple(size) not in grid: + grid[tuple(size)] = create_affine_grid( + size, self.assemble_matrix(size), output_size=self.output_size, + adjust_size=self.adjust_size, align_corners=self.align_corners, + ) + return grid - return data + def augment_grid(self, grid: Dict[Tuple, Tensor]) -> Dict[Tuple, Tensor]: + return grid def __add__(self, other): """ @@ -154,7 +143,7 @@ def __add__(self, other): Parameters ---------- - other : torch.Tensor, Affine + other : Tensor, Affine the other transformation Returns @@ -162,7 +151,7 @@ def __add__(self, other): StackedAffine a stacked affine transformation """ - if not isinstance(other, Affine): + if not isinstance(other, GridTransform): other = Affine(matrix=other, keys=self.keys, grad=self.grad, output_size=self.output_size, adjust_size=self.adjust_size, @@ -171,12 +160,15 @@ def __add__(self, other): align_corners=self.align_corners, **self.kwargs) - return StackedAffine(self, other, keys=self.keys, grad=self.grad, - output_size=self.output_size, - adjust_size=self.adjust_size, - interpolation_mode=self.interpolation_mode, - padding_mode=self.padding_mode, - align_corners=self.align_corners, **self.kwargs) + if isinstance(other, Affine): + return StackedAffine(self, other, keys=self.keys, grad=self.grad, + output_size=self.output_size, + adjust_size=self.adjust_size, + interpolation_mode=self.interpolation_mode, + padding_mode=self.padding_mode, + align_corners=self.align_corners, **self.kwargs) + else: + return super().__add__(other) def __radd__(self, other): """ @@ -185,7 +177,7 @@ def __radd__(self, other): Parameters ---------- - other : torch.Tensor, Affine + other : Tensor, Affine the other transformation Returns @@ -193,7 +185,7 @@ def __radd__(self, other): 
StackedAffine a stacked affine transformation """ - if not isinstance(other, Affine): + if not isinstance(other, GridTransform): other = Affine(matrix=other, keys=self.keys, grad=self.grad, output_size=self.output_size, adjust_size=self.adjust_size, @@ -201,13 +193,16 @@ def __radd__(self, other): padding_mode=self.padding_mode, align_corners=self.align_corners, **self.kwargs) - return StackedAffine(other, self, grad=other.grad, - output_size=other.output_size, - adjust_size=other.adjust_size, - interpolation_mode=other.interpolation_mode, - padding_mode=other.padding_mode, - align_corners=other.align_corners, - **other.kwargs) + if isinstance(other, Affine): + return StackedAffine(other, self, grad=other.grad, + output_size=other.output_size, + adjust_size=other.adjust_size, + interpolation_mode=other.interpolation_mode, + padding_mode=other.padding_mode, + align_corners=other.align_corners, + **other.kwargs) + else: + return super().__add__(other) class StackedAffine(Affine): @@ -279,25 +274,34 @@ def __init__( self.transforms = transforms - def assemble_matrix(self, **data) -> torch.Tensor: + def assemble_matrix(self, + batch_shape: Sequence[int], + device: Union[torch.device, str] = None, + dtype: Union[torch.dtype, str] = None, + ) -> Tensor: """ Handles the matrix assembly and stacking Parameters ---------- - **data : - the data to be transformed. Will be used to determine batchsize, - dimensionality, dtype and device + batch_shape : Sequence[int] + shape of batch + device: Union[torch.device, str] + device where grid will be cached + dtype: Union[torch.dtype, str] + data type of grid Returns ------- - torch.Tensor + Tensor the (batched) transformation matrix """ whole_trafo = None for trafo in self.transforms: - matrix = matrix_to_homogeneous(trafo.assemble_matrix(**data)) + matrix = matrix_to_homogeneous(trafo.assemble_matrix( + batch_shape=batch_shape, device=device, dtype=dtype + )) if whole_trafo is None: whole_trafo = matrix else: @@ -328,7 +332,7 @@ def __init__(self, Parameters ---------- - scale : torch.Tensor, int, float, optional + scale : Tensor, int, float, optional the scale factor(s). Supported are: * a single parameter (as float or int), which will be replicated for all dimensions and batch samples @@ -338,7 +342,7 @@ def __init__(self, batch samples * a parameter per sampler per dimension None will be treated as a scaling factor of 1 - rotation : torch.Tensor, int, float, optional + rotation : Tensor, int, float, optional the rotation factor(s). The rotation is performed in consecutive order axis0 -> axis1 (-> axis 2). Supported are: * a single parameter (as float or int), which will be replicated @@ -349,7 +353,7 @@ def __init__(self, batch samples * a parameter per sampler per dimension None will be treated as a rotation factor of 1 - translation : torch.Tensor, int, float + translation : Tensor, int, float the translation offset(s) relative to image (should be in the range [0, 1]). Supported are: * a single parameter (as float or int), which will be replicated @@ -399,26 +403,31 @@ def __init__(self, self.degree = degree self.image_transform = image_transform - def assemble_matrix(self, **data) -> torch.Tensor: + def assemble_matrix(self, + batch_shape: Sequence[int], + device: Union[torch.device, str] = None, + dtype: Union[torch.dtype, str] = None, + ) -> Tensor: """ Assembles the matrix (and takes care of batching and having it on the right device and in the correct dtype and dimensionality). Parameters ---------- - **data : - the data to be transformed. 
Will be used to determine batchsize, - dimensionality, dtype and device + batch_shape : Sequence[int] + shape of batch + device: Union[torch.device, str] + device where grid will be cached + dtype: Union[torch.dtype, str] + data type of grid Returns ------- - torch.Tensor + Tensor the (batched) transformation matrix """ - batchsize = data[self.keys[0]].shape[0] - ndim = len(data[self.keys[0]].shape) - 2 # channel and batch dim - device = data[self.keys[0]].device - dtype = data[self.keys[0]].dtype + batchsize = batch_shape[0] + ndim = len(batch_shape) - 2 # channel and batch dim self.matrix = parametrize_matrix( scale=self.scale, rotation=self.rotation, translation=self.translation, @@ -448,7 +457,7 @@ def __init__(self, Parameters ---------- - rotation : torch.Tensor, int, float, optional + rotation : Tensor, int, float, optional the rotation factor(s). Supported are: * a single parameter (as float or int), which will be replicated for all dimensions and batch samples @@ -522,7 +531,7 @@ def __init__(self, Parameters ---------- - translation : torch.Tensor, int, float + translation : Tensor, int, float the translation offset(s). The translation unit can be specified. Supported are: * a single parameter (as float or int), which will be replicated @@ -578,25 +587,33 @@ def __init__(self, **kwargs) self.unit = unit - def assemble_matrix(self, **data) -> torch.Tensor: + def assemble_matrix(self, + batch_shape: Sequence[int], + device: Union[torch.device, str] = None, + dtype: Union[torch.dtype, str] = None, + ) -> Tensor: """ Assembles the matrix (and takes care of batching and having it on the right device and in the correct dtype and dimensionality). Parameters ---------- - **data : - the data to be transformed. Will be used to determine batchsize, - dimensionality, dtype and device + batch_shape : Sequence[int] + shape of batch + device: Union[torch.device, str] + device where grid will be cached + dtype: Union[torch.dtype, str] + data type of grid Returns ------- - torch.Tensor + Tensor the (batched) transformation matrix [N, NDIM, NDIM] """ - matrix = super().assemble_matrix(**data) + matrix = super().assemble_matrix(batch_shape=batch_shape, + device=device, dtype=dtype) if self.unit.lower() == 'pixel': - img_size = torch.tensor(data[self.keys[0]].shape[2:]).to(matrix) + img_size = torch.tensor(batch_shape[2:]).to(matrix) matrix[..., -1] = matrix[..., -1] / img_size return matrix @@ -620,7 +637,7 @@ def __init__(self, Parameters ---------- - scale : torch.Tensor, int, float, optional + scale : Tensor, int, float, optional the scale factor(s). Supported are: * a single parameter (as float or int), which will be replicated for all dimensions and batch samples @@ -738,24 +755,32 @@ def __init__(self, align_corners=align_corners, **kwargs) - def assemble_matrix(self, **data) -> torch.Tensor: + def assemble_matrix(self, + batch_shape: Sequence[int], + device: Union[torch.device, str] = None, + dtype: Union[torch.dtype, str] = None, + ) -> Tensor: + """ Handles the matrix assembly and calculates the scale factors for resizing Parameters ---------- - **data : - the data to be transformed. 
Will be used to determine batchsize, - dimensionality, dtype and device + batch_shape : Sequence[int] + shape of batch + device: Union[torch.device, str] + device where grid will be cached + dtype: Union[torch.dtype, str] + data type of grid Returns ------- - torch.Tensor + Tensor the (batched) transformation matrix """ - curr_img_size = data[self.keys[0]].shape[2:] + curr_img_size = batch_shape[2:] was_scalar = check_scalar(self.output_size) @@ -765,7 +790,8 @@ def assemble_matrix(self, **data) -> torch.Tensor: self.scale = [self.output_size[i] / curr_img_size[-i] for i in range(len(curr_img_size))] - matrix = super().assemble_matrix(**data) + matrix = super().assemble_matrix(batch_shape=batch_shape, + device=device, dtype=dtype) if was_scalar: self.output_size = self.output_size[0] diff --git a/rising/transforms/functional/affine.py b/rising/transforms/functional/affine.py index 943f2107..8f98d211 100644 --- a/rising/transforms/functional/affine.py +++ b/rising/transforms/functional/affine.py @@ -9,7 +9,7 @@ __all__ = [ - 'affine_image_transform', + 'create_affine_grid', 'affine_point_transform', "create_rotation", "create_scale", @@ -434,20 +434,21 @@ def affine_point_transform(point_batch: torch.Tensor, return points_to_cartesian(transformed_points) -def affine_image_transform(image_batch: torch.Tensor, - matrix_batch: torch.Tensor, - output_size: tuple = None, - adjust_size: bool = False, - interpolation_mode: str = 'bilinear', - padding_mode: str = 'zeros', - align_corners: bool = False) -> torch.Tensor: +def create_affine_grid(batch_shape: Sequence[int], + matrix_batch: torch.Tensor, + output_size: tuple = None, + adjust_size: bool = False, + align_corners: bool = False, + device: Union[torch.device, str] = None, + dtype: Union[torch.dtype, str] = None, + ) -> torch.Tensor: """ Performs an affine transformation on a batch of images Parameters ---------- - image_batch : torch.Tensor - the batch to transform. Should have shape of [N, C, NDIM] + batch_shape : Sequence[int] + shape of batch matrix_batch : torch.Tensor a batch of affine matrices with shape [N, NDIM, NDIM+1] output_size : Iterable @@ -455,18 +456,16 @@ def affine_image_transform(image_batch: torch.Tensor, adjust_size : bool if True, the resulting image size will be calculated dynamically to ensure that the whole image fits. - interpolation_mode : str - interpolation mode to calculate output values 'bilinear' | 'nearest'. - Default: 'bilinear' - padding_mode : - padding mode for outside grid values - 'zeros' | 'border' | 'reflection'. Default: 'zeros' align_corners : Geometrically, we consider the pixels of the input as squares rather than points. If set to True, the extrema (-1 and 1) are considered as referring to the center points of the input’s corner pixels. If set to False, they are instead considered as referring to the corner points of the input’s corner pixels, making the sampling more resolution agnostic. 
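A short sketch of the intended two-step flow (shapes and the identity matrix are illustrative): the grid is created once from the batch shape and matrix, then sampled with torch.nn.functional.grid_sample:

    import torch
    from rising.transforms.functional.affine import create_affine_grid

    image_batch = torch.rand(1, 1, 32, 32)
    matrix = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]])  # identity, [N, NDIM, NDIM + 1]
    grid = create_affine_grid(image_batch.shape, matrix,
                              device=image_batch.device, dtype=image_batch.dtype)
    warped = torch.nn.functional.grid_sample(image_batch, grid)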
+ device: Union[torch.device, str] + device where grid will be cached + dtype: Union[torch.dtype, str] + data type of grid Returns ------- @@ -489,9 +488,9 @@ def affine_image_transform(image_batch: torch.Tensor, # add batch dimension if necessary if len(matrix_batch.shape) < 3: matrix_batch = matrix_batch[None, ...].expand( - image_batch.size(0), -1, -1).clone() + batch_shape[0], -1, -1).clone() - image_size = image_batch.shape[2:] + image_size = batch_shape[2:] if output_size is not None: if check_scalar(output_size): @@ -508,19 +507,15 @@ def affine_image_transform(image_batch: torch.Tensor, else: new_size = image_size - if len(image_size) < len(image_batch.shape): - missing_dims = len(image_batch.shape) - len(image_size) - new_size = (*image_batch.shape[:missing_dims], *new_size) - - matrix_batch = matrix_batch.to(image_batch) + if len(image_size) < len(batch_shape): + missing_dims = len(batch_shape) - len(image_size) + new_size = (*batch_shape[:missing_dims], *new_size) - grid = torch.nn.functional.affine_grid(matrix_batch, size=new_size, - align_corners=align_corners) + matrix_batch = matrix_batch.to(device=device, dtype=dtype) - return torch.nn.functional.grid_sample(image_batch, grid, - mode=interpolation_mode, - padding_mode=padding_mode, - align_corners=align_corners) + grid = torch.nn.functional.affine_grid( + matrix_batch, size=new_size, align_corners=align_corners) + return grid def _check_new_img_size(curr_img_size, matrix: torch.Tensor, diff --git a/rising/transforms/grid.py b/rising/transforms/grid.py index 167ec2c2..5ae5dd52 100644 --- a/rising/transforms/grid.py +++ b/rising/transforms/grid.py @@ -18,12 +18,15 @@ def __init__(self, interpolation_mode: str = 'bilinear', padding_mode: str = 'zeros', align_corners: bool = False, - grad: bool = False,): + grad: bool = False, + **kwargs, + ): super().__init__(grad=grad) self.keys = keys self.interpolation_mode = interpolation_mode self.padding_mode = padding_mode self.align_corners = align_corners + self.kwargs = kwargs self.grid: Dict[Tuple, Tensor] = None @@ -35,10 +38,12 @@ def forward(self, **data) -> dict: for key in self.keys: _grid = self.grid[tuple(data[key].shape)] + _grid = _grid.to(data[key]) data[key] = torch.nn.functional.grid_sample( data[key], _grid, mode=self.interpolation_mode, - padding_mode=self.padding_mode, align_corners=self.align_corners) + padding_mode=self.padding_mode, align_corners=self.align_corners, + **self.kwargs) self.grid = None return data @@ -46,8 +51,8 @@ def forward(self, **data) -> dict: def augment_grid(self, grid: Dict[Tuple, Tensor]) -> Dict[Tuple, Tensor]: raise NotImplementedError - def create_grid(self, input_size: Sequence[Sequence[int]], matrix: Tensor = None) -> \ - Dict[Tuple, Tensor]: + def create_grid(self, input_size: Sequence[Sequence[int]], + matrix: Tensor = None) -> Dict[Tuple, Tensor]: if matrix is None: matrix = get_batched_eye(batchsize=input_size[0][0], ndim=len(input_size[0]) - 2) matrix = matrix_to_homogeneous(matrix)[:, :-1] diff --git a/tests/transforms/functional/test_affine.py b/tests/transforms/functional/test_affine.py index 57f24457..e920cee0 100644 --- a/tests/transforms/functional/test_affine.py +++ b/tests/transforms/functional/test_affine.py @@ -1,7 +1,7 @@ import unittest import torch from rising.transforms.functional.affine import _check_new_img_size, \ - affine_point_transform, affine_image_transform, parametrize_matrix, \ + affine_point_transform, create_affine_grid, parametrize_matrix, \ create_rotation, create_translation, create_scale from 
rising.utils.affine import matrix_to_homogeneous, matrix_to_cartesian @@ -121,18 +121,24 @@ def test_affine_image_trafo(self): output_size=output_size): if output_size is not None and adjust_size: with self.assertWarns(UserWarning): - result = affine_image_transform( - image_batch=image_batch, + grid = create_affine_grid( + batch_shape=image_batch.shape, matrix_batch=matrix, output_size=output_size, - adjust_size=adjust_size) + adjust_size=adjust_size, + device=image_batch.device, + dtype=image_batch.dtype, + ) else: - result = affine_image_transform( - image_batch=image_batch, + grid = create_affine_grid( + batch_shape=image_batch.shape, matrix_batch=matrix, output_size=output_size, - adjust_size=adjust_size) - + adjust_size=adjust_size, + device=image_batch.device, + dtype=image_batch.dtype, + ) + result = torch.nn.functional.grid_sample(image_batch, grid) self.assertTupleEqual(result.shape[2:], target_size) def test_create_scale(self): diff --git a/tests/transforms/test_affine.py b/tests/transforms/test_affine.py index fbbd8f66..c133d35c 100644 --- a/tests/transforms/test_affine.py +++ b/tests/transforms/test_affine.py @@ -52,16 +52,18 @@ def test_affine_assemble_matrix(self): None, ] value_error = [False, False, False, True, True] - batch = {"data": torch.zeros(1, 1, 10, 10)} + batch = torch.zeros(1, 1, 10, 10) for matrix, expected, ve in zip(matrices, expected_matrices, value_error): with self.subTest(matrix=matrix, expected=expected): trafo = Affine(matrix=matrix) if ve: with self.assertRaises(ValueError): - assembled = trafo.assemble_matrix(**batch) + assembled = trafo.assemble_matrix( + batch.shape, device=batch.device, dtype=batch.dtype) else: - assembled = trafo.assemble_matrix(**batch) + assembled = trafo.assemble_matrix( + batch.shape, device=batch.device, dtype=batch.dtype) self.assertTrue(expected.allclose(assembled)) def test_affine_stacking(self): @@ -92,9 +94,10 @@ def test_stacked_transformation_assembly(self): second_matrix = torch.tensor([[[4., 0., 3.], [0., 5., 4.]]]) trafo = StackedAffine([first_matrix, second_matrix]) - sample = {'data': torch.rand(1, 3, 25, 25)} + sample = torch.rand(1, 3, 25, 25) - matrix = trafo.assemble_matrix(**sample) + matrix = trafo.assemble_matrix(sample.shape, + dtype=sample.dtype, device=sample.device) target_matrix = matrix_to_cartesian( torch.bmm( @@ -128,13 +131,14 @@ def test_affine_subtypes(self): def test_translation_assemble_matrix_with_pixel(self): trafo = Translate([1, 10, 100], unit='pixel') - sample = {'data': torch.rand(3, 3, 100, 100)} - expected = torch.tensor([[1., 0., -0.01], [0., 1., -0.01], - [1., 0., -0.1], [0., 1., -0.1], - [1., 0., -1.], [0., 1., -1.]]) - - trafo.assemble_matrix(**sample) - self.assertTrue(expected.allclose(expected)) + sample = torch.rand(3, 3, 100, 100) + expected = torch.tensor([[[1., 0., -0.01], [0., 1., -0.01]], + [[1., 0., -0.1], [0., 1., -0.1]], + [[1., 0., -1.], [0., 1., -1.]]]) + + out = trafo.assemble_matrix( + sample.shape, device=sample.device, dtype=sample.dtype) + self.assertTrue(expected.allclose(out)) if __name__ == '__main__': From 8cd63390b8d404201336f30762ef31a301ff72f0 Mon Sep 17 00:00:00 2001 From: Michael Baumgartner Date: Sun, 23 Feb 2020 16:00:23 +0000 Subject: [PATCH 17/20] autopep8 fix --- rising/transforms/affine.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/rising/transforms/affine.py b/rising/transforms/affine.py index d8933d6f..be002381 100644 --- a/rising/transforms/affine.py +++ b/rising/transforms/affine.py @@ -130,7 +130,7 @@ def 
create_grid(self, input_size: Sequence[Sequence[int]], grid[tuple(size)] = create_affine_grid( size, self.assemble_matrix(size), output_size=self.output_size, adjust_size=self.adjust_size, align_corners=self.align_corners, - ) + ) return grid def augment_grid(self, grid: Dict[Tuple, Tensor]) -> Dict[Tuple, Tensor]: @@ -760,7 +760,6 @@ def assemble_matrix(self, device: Union[torch.device, str] = None, dtype: Union[torch.dtype, str] = None, ) -> Tensor: - """ Handles the matrix assembly and calculates the scale factors for resizing From d509405183bb176ad2b5293b295590c7348d5368 Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Sun, 23 Feb 2020 19:00:23 +0100 Subject: [PATCH 18/20] prototyping --- notebooks/transformations.ipynb | 426 ++++++++++++++++++++++++++- rising/transforms/__init__.py | 1 + rising/transforms/functional/crop.py | 44 ++- rising/transforms/grid.py | 101 ++++++- 4 files changed, 549 insertions(+), 23 deletions(-) diff --git a/notebooks/transformations.ipynb b/notebooks/transformations.ipynb index 2abe6d3d..94407d8a 100644 --- a/notebooks/transformations.ipynb +++ b/notebooks/transformations.ipynb @@ -18,14 +18,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2020-02-16T16:54:54.920459Z", "start_time": "2020-02-16T16:54:54.669509Z" } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running test environment: False\n" + ] + } + ], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", @@ -63,14 +71,23 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2020-02-16T16:55:01.394975Z", "start_time": "2020-02-16T16:55:00.893340Z" } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Image shape (192, 192, 174)\n", + "Image shape (192, 192, 174)\n" + ] + } + ], "source": [ "import SimpleITK as sitk\n", "import numpy as np\n", @@ -90,14 +107,23 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2020-02-16T16:55:04.255613Z", "start_time": "2020-02-16T16:55:03.213336Z" } }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/micha/miniconda3/envs/phoenix/lib/python3.7/site-packages/qtpy/__init__.py:216: RuntimeWarning: Selected binding \"pyqt5\" could not be found, using \"pyside2\"\n", + " 'using \"{}\"'.format(initial_api, API), RuntimeWarning)\n" + ] + } + ], "source": [ "if TEST_ENV:\n", " def view_batch(batch):\n", @@ -112,7 +138,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2020-02-16T16:55:54.082493Z", @@ -141,14 +167,26 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2020-02-16T16:55:06.109008Z", "start_time": "2020-02-16T16:55:06.069336Z" } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Transformed data shape: torch.Size([1, 1, 192, 192, 174])\n", + "Transformed mask shape: torch.Size([1, 1, 192, 192, 174])\n", + "Transformed data min: 0.0\n", + "Transformed data max: 502.0\n", + "Transformed data mean: 37.62009048461914\n" + ] + } + ], "source": [ "print(f\"Transformed data shape: {batch['data'].shape}\")\n", "print(f\"Transformed mask shape: {batch['mask'].shape}\")\n", @@ -159,14 +197,26 @@ }, 
{ "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2020-02-16T16:55:57.391117Z", "start_time": "2020-02-16T16:55:55.675294Z" } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Transformed data shape: torch.Size([1, 1, 192, 192, 174])\n", + "Transformed mask shape: torch.Size([1, 1, 192, 192, 174])\n", + "Transformed data min: 0.0\n", + "Transformed data max: 445.4209289550781\n", + "Transformed data mean: 110.88246154785156\n" + ] + } + ], "source": [ "trafo = Scale(1.5, adjust_size=False)\n", "transformed = apply_transform(trafo, batch)\n", @@ -205,6 +255,360 @@ "view_batch(transformed)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trafo = CenterCropGrid(100)\n", + "transformed = apply_transform(trafo, batch)\n", + "view_batch(transformed)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trafo = RandomDistortion(noise_type=\"normal\", noise_kwargs={\"mean\": 0.1, \"std\": 0.005})\n", + "transformed = apply_transform(trafo, batch)\n", + "view_batch(transformed)" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(0.4000)\n", + "tensor(0.0002)\n", + "{(1, 1, 192, 192, 174): tensor([[[[[-0.5966, -0.5969, -0.5969],\n", + " [-0.5927, -0.5999, -0.5999],\n", + " [-0.5887, -0.6029, -0.6029],\n", + " ...,\n", + " [ 0.5887, -0.6029, -0.6029],\n", + " [ 0.5927, -0.5999, -0.5999],\n", + " [ 0.5966, -0.5969, -0.5969]],\n", + "\n", + " [[-0.5993, -0.5934, -0.5996],\n", + " [-0.5954, -0.5964, -0.6027],\n", + " [-0.5914, -0.5994, -0.6057],\n", + " ...,\n", + " [ 0.5914, -0.5994, -0.6057],\n", + " [ 0.5954, -0.5964, -0.6027],\n", + " [ 0.5993, -0.5934, -0.5996]],\n", + "\n", + " [[-0.6020, -0.5898, -0.6024],\n", + " [-0.5981, -0.5927, -0.6054],\n", + " [-0.5940, -0.5957, -0.6084],\n", + " ...,\n", + " [ 0.5940, -0.5957, -0.6084],\n", + " [ 0.5981, -0.5927, -0.6054],\n", + " [ 0.6020, -0.5898, -0.6024]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.6020, 0.5898, -0.6024],\n", + " [-0.5981, 0.5927, -0.6054],\n", + " [-0.5940, 0.5957, -0.6084],\n", + " ...,\n", + " [ 0.5940, 0.5957, -0.6084],\n", + " [ 0.5981, 0.5927, -0.6054],\n", + " [ 0.6020, 0.5898, -0.6024]],\n", + "\n", + " [[-0.5993, 0.5934, -0.5996],\n", + " [-0.5954, 0.5964, -0.6027],\n", + " [-0.5914, 0.5994, -0.6057],\n", + " ...,\n", + " [ 0.5914, 0.5994, -0.6057],\n", + " [ 0.5954, 0.5964, -0.6027],\n", + " [ 0.5993, 0.5934, -0.5996]],\n", + "\n", + " [[-0.5966, 0.5969, -0.5969],\n", + " [-0.5927, 0.5999, -0.5999],\n", + " [-0.5887, 0.6029, -0.6029],\n", + " ...,\n", + " [ 0.5887, 0.6029, -0.6029],\n", + " [ 0.5927, 0.5999, -0.5999],\n", + " [ 0.5966, 0.5969, -0.5969]]],\n", + "\n", + "\n", + " [[[-0.5993, -0.5996, -0.5934],\n", + " [-0.5954, -0.6027, -0.5964],\n", + " [-0.5914, -0.6057, -0.5994],\n", + " ...,\n", + " [ 0.5914, -0.6057, -0.5994],\n", + " [ 0.5954, -0.6027, -0.5964],\n", + " [ 0.5993, -0.5996, -0.5934]],\n", + "\n", + " [[-0.6021, -0.5961, -0.5961],\n", + " [-0.5981, -0.5991, -0.5991],\n", + " [-0.5941, -0.6021, -0.6021],\n", + " ...,\n", + " [ 0.5941, -0.6021, -0.6021],\n", + " [ 0.5981, -0.5991, -0.5991],\n", + " [ 0.6021, -0.5961, -0.5961]],\n", + "\n", + " [[-0.6048, -0.5925, -0.5988],\n", + " [-0.6008, -0.5954, -0.6018],\n", + " [-0.5967, -0.5984, -0.6048],\n", + " 
...,\n", + " [ 0.5967, -0.5984, -0.6048],\n", + " [ 0.6008, -0.5954, -0.6018],\n", + " [ 0.6048, -0.5925, -0.5988]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.6048, 0.5925, -0.5988],\n", + " [-0.6008, 0.5954, -0.6018],\n", + " [-0.5967, 0.5984, -0.6048],\n", + " ...,\n", + " [ 0.5967, 0.5984, -0.6048],\n", + " [ 0.6008, 0.5954, -0.6018],\n", + " [ 0.6048, 0.5925, -0.5988]],\n", + "\n", + " [[-0.6021, 0.5961, -0.5961],\n", + " [-0.5981, 0.5991, -0.5991],\n", + " [-0.5941, 0.6021, -0.6021],\n", + " ...,\n", + " [ 0.5941, 0.6021, -0.6021],\n", + " [ 0.5981, 0.5991, -0.5991],\n", + " [ 0.6021, 0.5961, -0.5961]],\n", + "\n", + " [[-0.5993, 0.5996, -0.5934],\n", + " [-0.5954, 0.6027, -0.5964],\n", + " [-0.5914, 0.6057, -0.5994],\n", + " ...,\n", + " [ 0.5914, 0.6057, -0.5994],\n", + " [ 0.5954, 0.6027, -0.5964],\n", + " [ 0.5993, 0.5996, -0.5934]]],\n", + "\n", + "\n", + " [[[-0.6020, -0.6024, -0.5898],\n", + " [-0.5981, -0.6054, -0.5927],\n", + " [-0.5940, -0.6084, -0.5957],\n", + " ...,\n", + " [ 0.5940, -0.6084, -0.5957],\n", + " [ 0.5981, -0.6054, -0.5927],\n", + " [ 0.6020, -0.6024, -0.5898]],\n", + "\n", + " [[-0.6048, -0.5988, -0.5925],\n", + " [-0.6008, -0.6018, -0.5954],\n", + " [-0.5967, -0.6048, -0.5984],\n", + " ...,\n", + " [ 0.5967, -0.6048, -0.5984],\n", + " [ 0.6008, -0.6018, -0.5954],\n", + " [ 0.6048, -0.5988, -0.5925]],\n", + "\n", + " [[-0.6075, -0.5951, -0.5951],\n", + " [-0.6035, -0.5981, -0.5981],\n", + " [-0.5994, -0.6011, -0.6011],\n", + " ...,\n", + " [ 0.5994, -0.6011, -0.6011],\n", + " [ 0.6035, -0.5981, -0.5981],\n", + " [ 0.6075, -0.5951, -0.5951]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.6075, 0.5951, -0.5951],\n", + " [-0.6035, 0.5981, -0.5981],\n", + " [-0.5994, 0.6011, -0.6011],\n", + " ...,\n", + " [ 0.5994, 0.6011, -0.6011],\n", + " [ 0.6035, 0.5981, -0.5981],\n", + " [ 0.6075, 0.5951, -0.5951]],\n", + "\n", + " [[-0.6048, 0.5988, -0.5925],\n", + " [-0.6008, 0.6018, -0.5954],\n", + " [-0.5967, 0.6048, -0.5984],\n", + " ...,\n", + " [ 0.5967, 0.6048, -0.5984],\n", + " [ 0.6008, 0.6018, -0.5954],\n", + " [ 0.6048, 0.5988, -0.5925]],\n", + "\n", + " [[-0.6020, 0.6024, -0.5898],\n", + " [-0.5981, 0.6054, -0.5927],\n", + " [-0.5940, 0.6084, -0.5957],\n", + " ...,\n", + " [ 0.5940, 0.6084, -0.5957],\n", + " [ 0.5981, 0.6054, -0.5927],\n", + " [ 0.6020, 0.6024, -0.5898]]],\n", + "\n", + "\n", + " ...,\n", + "\n", + "\n", + " [[[-0.6020, -0.6024, 0.5898],\n", + " [-0.5981, -0.6054, 0.5927],\n", + " [-0.5940, -0.6084, 0.5957],\n", + " ...,\n", + " [ 0.5940, -0.6084, 0.5957],\n", + " [ 0.5981, -0.6054, 0.5927],\n", + " [ 0.6020, -0.6024, 0.5898]],\n", + "\n", + " [[-0.6048, -0.5988, 0.5925],\n", + " [-0.6008, -0.6018, 0.5954],\n", + " [-0.5967, -0.6048, 0.5984],\n", + " ...,\n", + " [ 0.5967, -0.6048, 0.5984],\n", + " [ 0.6008, -0.6018, 0.5954],\n", + " [ 0.6048, -0.5988, 0.5925]],\n", + "\n", + " [[-0.6075, -0.5951, 0.5951],\n", + " [-0.6035, -0.5981, 0.5981],\n", + " [-0.5994, -0.6011, 0.6011],\n", + " ...,\n", + " [ 0.5994, -0.6011, 0.6011],\n", + " [ 0.6035, -0.5981, 0.5981],\n", + " [ 0.6075, -0.5951, 0.5951]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.6075, 0.5951, 0.5951],\n", + " [-0.6035, 0.5981, 0.5981],\n", + " [-0.5994, 0.6011, 0.6011],\n", + " ...,\n", + " [ 0.5994, 0.6011, 0.6011],\n", + " [ 0.6035, 0.5981, 0.5981],\n", + " [ 0.6075, 0.5951, 0.5951]],\n", + "\n", + " [[-0.6048, 0.5988, 0.5925],\n", + " [-0.6008, 0.6018, 0.5954],\n", + " [-0.5967, 0.6048, 0.5984],\n", + " ...,\n", + " [ 0.5967, 0.6048, 0.5984],\n", + " [ 0.6008, 0.6018, 0.5954],\n", + " [ 
0.6048, 0.5988, 0.5925]],\n", + "\n", + " [[-0.6020, 0.6024, 0.5898],\n", + " [-0.5981, 0.6054, 0.5927],\n", + " [-0.5940, 0.6084, 0.5957],\n", + " ...,\n", + " [ 0.5940, 0.6084, 0.5957],\n", + " [ 0.5981, 0.6054, 0.5927],\n", + " [ 0.6020, 0.6024, 0.5898]]],\n", + "\n", + "\n", + " [[[-0.5993, -0.5996, 0.5934],\n", + " [-0.5954, -0.6027, 0.5964],\n", + " [-0.5914, -0.6057, 0.5994],\n", + " ...,\n", + " [ 0.5914, -0.6057, 0.5994],\n", + " [ 0.5954, -0.6027, 0.5964],\n", + " [ 0.5993, -0.5996, 0.5934]],\n", + "\n", + " [[-0.6021, -0.5961, 0.5961],\n", + " [-0.5981, -0.5991, 0.5991],\n", + " [-0.5941, -0.6021, 0.6021],\n", + " ...,\n", + " [ 0.5941, -0.6021, 0.6021],\n", + " [ 0.5981, -0.5991, 0.5991],\n", + " [ 0.6021, -0.5961, 0.5961]],\n", + "\n", + " [[-0.6048, -0.5925, 0.5988],\n", + " [-0.6008, -0.5954, 0.6018],\n", + " [-0.5967, -0.5984, 0.6048],\n", + " ...,\n", + " [ 0.5967, -0.5984, 0.6048],\n", + " [ 0.6008, -0.5954, 0.6018],\n", + " [ 0.6048, -0.5925, 0.5988]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.6048, 0.5925, 0.5988],\n", + " [-0.6008, 0.5954, 0.6018],\n", + " [-0.5967, 0.5984, 0.6048],\n", + " ...,\n", + " [ 0.5967, 0.5984, 0.6048],\n", + " [ 0.6008, 0.5954, 0.6018],\n", + " [ 0.6048, 0.5925, 0.5988]],\n", + "\n", + " [[-0.6021, 0.5961, 0.5961],\n", + " [-0.5981, 0.5991, 0.5991],\n", + " [-0.5941, 0.6021, 0.6021],\n", + " ...,\n", + " [ 0.5941, 0.6021, 0.6021],\n", + " [ 0.5981, 0.5991, 0.5991],\n", + " [ 0.6021, 0.5961, 0.5961]],\n", + "\n", + " [[-0.5993, 0.5996, 0.5934],\n", + " [-0.5954, 0.6027, 0.5964],\n", + " [-0.5914, 0.6057, 0.5994],\n", + " ...,\n", + " [ 0.5914, 0.6057, 0.5994],\n", + " [ 0.5954, 0.6027, 0.5964],\n", + " [ 0.5993, 0.5996, 0.5934]]],\n", + "\n", + "\n", + " [[[-0.5966, -0.5969, 0.5969],\n", + " [-0.5927, -0.5999, 0.5999],\n", + " [-0.5887, -0.6029, 0.6029],\n", + " ...,\n", + " [ 0.5887, -0.6029, 0.6029],\n", + " [ 0.5927, -0.5999, 0.5999],\n", + " [ 0.5966, -0.5969, 0.5969]],\n", + "\n", + " [[-0.5993, -0.5934, 0.5996],\n", + " [-0.5954, -0.5964, 0.6027],\n", + " [-0.5914, -0.5994, 0.6057],\n", + " ...,\n", + " [ 0.5914, -0.5994, 0.6057],\n", + " [ 0.5954, -0.5964, 0.6027],\n", + " [ 0.5993, -0.5934, 0.5996]],\n", + "\n", + " [[-0.6020, -0.5898, 0.6024],\n", + " [-0.5981, -0.5927, 0.6054],\n", + " [-0.5940, -0.5957, 0.6084],\n", + " ...,\n", + " [ 0.5940, -0.5957, 0.6084],\n", + " [ 0.5981, -0.5927, 0.6054],\n", + " [ 0.6020, -0.5898, 0.6024]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.6020, 0.5898, 0.6024],\n", + " [-0.5981, 0.5927, 0.6054],\n", + " [-0.5940, 0.5957, 0.6084],\n", + " ...,\n", + " [ 0.5940, 0.5957, 0.6084],\n", + " [ 0.5981, 0.5927, 0.6054],\n", + " [ 0.6020, 0.5898, 0.6024]],\n", + "\n", + " [[-0.5993, 0.5934, 0.5996],\n", + " [-0.5954, 0.5964, 0.6027],\n", + " [-0.5914, 0.5994, 0.6057],\n", + " ...,\n", + " [ 0.5914, 0.5994, 0.6057],\n", + " [ 0.5954, 0.5964, 0.6027],\n", + " [ 0.5993, 0.5934, 0.5996]],\n", + "\n", + " [[-0.5966, 0.5969, 0.5969],\n", + " [-0.5927, 0.5999, 0.5999],\n", + " [-0.5887, 0.6029, 0.6029],\n", + " ...,\n", + " [ 0.5887, 0.6029, 0.6029],\n", + " [ 0.5927, 0.5999, 0.5999],\n", + " [ 0.5966, 0.5969, 0.5969]]]]])}\n", + "Transformed data shape: torch.Size([1, 1, 192, 192, 174])\n", + "Transformed mask shape: torch.Size([1, 1, 192, 192, 174])\n", + "Transformed data min: 0.0\n", + "Transformed data max: 459.3555603027344\n", + "Transformed data mean: 54.32514190673828\n" + ] + } + ], + "source": [ + "trafo = RadialDistortion(scale=[0.1, 1., 0.1])\n", + "transformed = apply_transform(trafo, batch)\n", + 
"view_batch(transformed)" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/rising/transforms/__init__.py b/rising/transforms/__init__.py index 026fe35f..27b5f141 100644 --- a/rising/transforms/__init__.py +++ b/rising/transforms/__init__.py @@ -9,3 +9,4 @@ from rising.transforms.utility import * from rising.transforms.tensor import * from rising.transforms.affine import * +from rising.transforms.grid import * diff --git a/rising/transforms/functional/crop.py b/rising/transforms/functional/crop.py index cc358c4c..121f4203 100644 --- a/rising/transforms/functional/crop.py +++ b/rising/transforms/functional/crop.py @@ -7,7 +7,8 @@ __all__ = ["crop", "center_crop", "random_crop"] -def crop(data: torch.Tensor, corner: Sequence[int], size: Sequence[int]): +def crop(data: torch.Tensor, corner: Sequence[int], size: Sequence[int], + exclude_last: bool = False): """ Extract crop from last dimensions of data @@ -19,6 +20,9 @@ def crop(data: torch.Tensor, corner: Sequence[int], size: Sequence[int]): top left corner point size: Sequence[int] size of patch + exclude_last: bool + exclude last dimension from cropping (this is used to crop grids + from :class:`GridTransform` Returns ------- @@ -26,15 +30,20 @@ def crop(data: torch.Tensor, corner: Sequence[int], size: Sequence[int]): cropped data """ _slices = [] - if len(corner) < data.ndim: - for i in range(data.ndim - len(corner)): + ndim = data.ndimension() - int(bool(exclude_last)) + if len(corner) < ndim: + for i in range(ndim - len(corner)): _slices.append(slice(0, data.shape[i])) _slices = _slices + [slice(c, c + s) for c, s in zip(corner, size)] + if exclude_last: + _slices.append(_slices.append(slice(0, data.shape[-1]))) + print(_slices) return data[_slices] -def center_crop(data: torch.Tensor, size: Union[int, Sequence[int]]) -> torch.Tensor: +def center_crop(data: torch.Tensor, size: Union[int, Sequence[int]], + exclude_last: bool = False) -> torch.Tensor: """ Crop patch from center @@ -44,6 +53,9 @@ def center_crop(data: torch.Tensor, size: Union[int, Sequence[int]]) -> torch.Te input tensor size: Union[int, Sequence[int]] size of patch + exclude_last: bool + exclude last dimension from cropping (this is used to crop grids + from :class:`GridTransform` Returns ------- @@ -55,12 +67,19 @@ def center_crop(data: torch.Tensor, size: Union[int, Sequence[int]]) -> torch.Te if not isinstance(size[0], int): size = [int(s) for s in size] - corner = [int(round((img_dim - crop_dim) / 2.)) for img_dim, crop_dim in zip(data.shape[2:], size)] + if exclude_last: + size = size[:-1] + data_shape = data.shape[2:-1] + else: + data_shape = data.shape[2:] + + corner = [int(round((img_dim - crop_dim) / 2.)) for img_dim, crop_dim in zip(data_shape, size)] return crop(data, corner, size) def random_crop(data: torch.Tensor, size: Union[int, Sequence[int]], - dist: Union[int, Sequence[int]] = 0) -> torch.Tensor: + dist: Union[int, Sequence[int]] = 0, + exclude_last: bool = False) -> torch.Tensor: """ Crop random patch/volume from input tensor @@ -72,6 +91,9 @@ def random_crop(data: torch.Tensor, size: Union[int, Sequence[int]], size of patch/volume dist: Union[int, Sequence[int]] minimum distance to border. 
By default zero + exclude_last: bool + exclude last dimension from cropping (this is used to crop grids + from :class:`GridTransform` Returns ------- @@ -85,9 +107,15 @@ def random_crop(data: torch.Tensor, size: Union[int, Sequence[int]], if not isinstance(size[0], int): size = [int(s) for s in size] - if any([crop_dim + dist_dim >= img_dim for img_dim, crop_dim, dist_dim in zip(data.shape[2:], size, dist)]): + if exclude_last: + size = size[:-1] + data_shape = data.shape[2:-1] + else: + data_shape = data.shape[2:] + + if any([crop_dim + dist_dim >= img_dim for img_dim, crop_dim, dist_dim in zip(data_shape, size, dist)]): raise TypeError(f"Crop can not be realized with given size {size} and dist {dist}.") corner = [random.randrange(0, img_dim - crop_dim - dist_dim) for - img_dim, crop_dim, dist_dim in zip(data.shape[2:], size, dist)] + img_dim, crop_dim, dist_dim in zip(data_shape, size, dist)] return crop(data, corner, size) diff --git a/rising/transforms/grid.py b/rising/transforms/grid.py index 5ae5dd52..d3174d44 100644 --- a/rising/transforms/grid.py +++ b/rising/transforms/grid.py @@ -7,9 +7,11 @@ from rising.transforms import AbstractTransform from rising.utils.affine import get_batched_eye, matrix_to_homogeneous +from rising.transforms.functional import center_crop, random_crop, add_noise -__all__ = ["GridTransform", "StackedGridTransform"] +__all__ = ["GridTransform", "StackedGridTransform", + "CenterCropGrid", "RandomCropGrid", "RandomDistortion", "RadialDistortion"] class GridTransform(AbstractTransform): @@ -42,8 +44,7 @@ def forward(self, **data) -> dict: data[key] = torch.nn.functional.grid_sample( data[key], _grid, mode=self.interpolation_mode, - padding_mode=self.padding_mode, align_corners=self.align_corners, - **self.kwargs) + padding_mode=self.padding_mode, align_corners=self.align_corners) self.grid = None return data @@ -61,7 +62,7 @@ def create_grid(self, input_size: Sequence[Sequence[int]], for size in input_size: if tuple(size) not in grid: grid[tuple(size)] = torch.nn.functional.affine_grid( - matrix, size=input_size, align_corners=self.align_corners) + matrix, size=size, align_corners=self.align_corners) return grid def __add__(self, other): @@ -92,3 +93,95 @@ def augment_grid(self, grid: Tensor) -> Tensor: for transform in self.transforms: grid = transform.augment_grid(grid) return grid + + +class CenterCropGrid(GridTransform): + def __init__(self, + size: Union[int, Sequence[int]], + keys: Sequence[str] = ('data',), + interpolation_mode: str = 'bilinear', + padding_mode: str = 'zeros', + align_corners: bool = False, + grad: bool = False, + **kwargs,): + super().__init__(keys=keys, interpolation_mode=interpolation_mode, + padding_mode=padding_mode, align_corners=align_corners, + grad=grad, **kwargs) + self.size = size + + def augment_grid(self, grid: Dict[Tuple, Tensor]) -> Dict[Tuple, Tensor]: + return {key: center_crop(item, size=self.size, exclude_last=True) + for key, item in grid.items()} + + +class RandomCropGrid(GridTransform): + def __init__(self, + size: Union[int, Sequence[int]], + dist: Union[int, Sequence[int]] = 0, + keys: Sequence[str] = ('data',), + interpolation_mode: str = 'bilinear', + padding_mode: str = 'zeros', + align_corners: bool = False, + grad: bool = False, + **kwargs,): + super().__init__(keys=keys, interpolation_mode=interpolation_mode, + padding_mode=padding_mode, align_corners=align_corners, + grad=grad, **kwargs) + self.size = size + self.dist = dist + + def augment_grid(self, grid: Dict[Tuple, Tensor]) -> Dict[Tuple, Tensor]: + 
return {key: random_crop(item, size=self.size, dist=self.dist, exclude_last=True) + for key, item in grid.items()} + + +class RandomDistortion(GridTransform): + def __init__(self, + noise_type: str, + noise_kwargs: dict = None, + keys: Sequence[str] = ('data',), + interpolation_mode: str = 'bilinear', + padding_mode: str = 'zeros', + align_corners: bool = False, + grad: bool = False, + **kwargs,): + super().__init__(keys=keys, interpolation_mode=interpolation_mode, + padding_mode=padding_mode, align_corners=align_corners, + grad=grad, **kwargs) + self.noise_type = noise_type + self.noise_kwargs = noise_kwargs if noise_kwargs is not None else {} + + def augment_grid(self, grid: Dict[Tuple, Tensor]) -> Dict[Tuple, Tensor]: + return {key: add_noise(item, noise_type=self.noise_type, **self.noise_kwargs) + for key, item in grid.items()} + + +class RadialDistortion(GridTransform): + def __init__(self, + scale: float, + keys: Sequence[str] = ('data',), + interpolation_mode: str = 'bilinear', + padding_mode: str = 'zeros', + align_corners: bool = False, + grad: bool = False, + **kwargs,): + super().__init__(keys=keys, interpolation_mode=interpolation_mode, + padding_mode=padding_mode, align_corners=align_corners, + grad=grad, **kwargs) + self.scale = scale + + def augment_grid(self, grid: Dict[Tuple, Tensor]) -> Dict[Tuple, Tensor]: + + new_grid = {key: radial_distortion_grid(item, scale=self.scale) + for key, item in grid.items()} + print(new_grid) + return new_grid + + +def radial_distortion_grid(grid: Tensor, scale: float) -> Tensor: + dist = torch.norm(grid, 2, dim=-1, keepdim=True) + dist = dist / dist.max() + distortion = (scale[0] * dist.pow(3) + scale[1] * dist.pow(2) + scale[2] * dist) / 3 + print(distortion.max()) + print(distortion.min()) + return grid * (1 - distortion) From c20b42fda01374f0e6d8479f6f46946bbe041f32 Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Sun, 23 Feb 2020 21:31:04 +0100 Subject: [PATCH 19/20] center and random crop for grids --- rising/transforms/functional/crop.py | 57 +++++++++++++++------------- rising/transforms/grid.py | 11 +++++- 2 files changed, 39 insertions(+), 29 deletions(-) diff --git a/rising/transforms/functional/crop.py b/rising/transforms/functional/crop.py index 121f4203..2c47cbc2 100644 --- a/rising/transforms/functional/crop.py +++ b/rising/transforms/functional/crop.py @@ -8,21 +8,23 @@ def crop(data: torch.Tensor, corner: Sequence[int], size: Sequence[int], - exclude_last: bool = False): + grid_crop: bool = False): """ Extract crop from last dimensions of data Parameters ---------- data: torch.Tensor - input tensor + input tensor [... , spatial dims] spatial dims can be arbitrary + spatial dimensions. Leading dimensions will be preserved. 
corner: Sequence[int] top left corner point size: Sequence[int] size of patch - exclude_last: bool - exclude last dimension from cropping (this is used to crop grids - from :class:`GridTransform` + grid_crop: bool + crop from grid of shape [N, spatial dims, NDIM], where N is the batch + size, spatial dims can be arbitrary spatial dimensions and NDIM + is the number of spatial dimensions Returns ------- @@ -30,32 +32,33 @@ def crop(data: torch.Tensor, corner: Sequence[int], size: Sequence[int], cropped data """ _slices = [] - ndim = data.ndimension() - int(bool(exclude_last)) + ndim = data.ndimension() - int(bool(grid_crop)) if len(corner) < ndim: for i in range(ndim - len(corner)): _slices.append(slice(0, data.shape[i])) _slices = _slices + [slice(c, c + s) for c, s in zip(corner, size)] - if exclude_last: - _slices.append(_slices.append(slice(0, data.shape[-1]))) - print(_slices) + if grid_crop: + _slices.append(slice(0, data.shape[-1])) return data[_slices] def center_crop(data: torch.Tensor, size: Union[int, Sequence[int]], - exclude_last: bool = False) -> torch.Tensor: + grid_crop: bool = False) -> torch.Tensor: """ Crop patch from center Parameters ---------- data: torch.Tensor - input tensor + input tensor [... , spatial dims] spatial dims can be arbitrary + spatial dimensions. Leading dimensions will be preserved. size: Union[int, Sequence[int]] size of patch - exclude_last: bool - exclude last dimension from cropping (this is used to crop grids - from :class:`GridTransform` + grid_crop: bool + crop from grid of shape [N, spatial dims, NDIM], where N is the batch + size, spatial dims can be arbitrary spatial dimensions and NDIM + is the number of spatial dimensions Returns ------- @@ -67,33 +70,34 @@ def center_crop(data: torch.Tensor, size: Union[int, Sequence[int]], if not isinstance(size[0], int): size = [int(s) for s in size] - if exclude_last: - size = size[:-1] - data_shape = data.shape[2:-1] + if grid_crop: + data_shape = data.shape[1:-1] else: data_shape = data.shape[2:] corner = [int(round((img_dim - crop_dim) / 2.)) for img_dim, crop_dim in zip(data_shape, size)] - return crop(data, corner, size) + return crop(data, corner, size, grid_crop=grid_crop) def random_crop(data: torch.Tensor, size: Union[int, Sequence[int]], dist: Union[int, Sequence[int]] = 0, - exclude_last: bool = False) -> torch.Tensor: + grid_crop: bool = False) -> torch.Tensor: """ Crop random patch/volume from input tensor Parameters ---------- data: torch.Tensor - input tensor + input tensor [... , spatial dims] spatial dims can be arbitrary + spatial dimensions. Leading dimensions will be preserved. size: Union[int, Sequence[int]] size of patch/volume dist: Union[int, Sequence[int]] minimum distance to border. 
By default zero - exclude_last: bool - exclude last dimension from cropping (this is used to crop grids - from :class:`GridTransform` + grid_crop: bool + crop from grid of shape [N, spatial dims, NDIM], where N is the batch + size, spatial dims can be arbitrary spatial dimensions and NDIM + is the number of spatial dimensions Returns ------- @@ -107,9 +111,8 @@ def random_crop(data: torch.Tensor, size: Union[int, Sequence[int]], if not isinstance(size[0], int): size = [int(s) for s in size] - if exclude_last: - size = size[:-1] - data_shape = data.shape[2:-1] + if grid_crop: + data_shape = data.shape[1:-1] else: data_shape = data.shape[2:] @@ -118,4 +121,4 @@ def random_crop(data: torch.Tensor, size: Union[int, Sequence[int]], corner = [random.randrange(0, img_dim - crop_dim - dist_dim) for img_dim, crop_dim, dist_dim in zip(data_shape, size, dist)] - return crop(data, corner, size) + return crop(data, corner, size, grid_crop=grid_crop) diff --git a/rising/transforms/grid.py b/rising/transforms/grid.py index d3174d44..2bc58db3 100644 --- a/rising/transforms/grid.py +++ b/rising/transforms/grid.py @@ -110,7 +110,7 @@ def __init__(self, self.size = size def augment_grid(self, grid: Dict[Tuple, Tensor]) -> Dict[Tuple, Tensor]: - return {key: center_crop(item, size=self.size, exclude_last=True) + return {key: center_crop(item, size=self.size, grid_crop=True) for key, item in grid.items()} @@ -131,7 +131,7 @@ def __init__(self, self.dist = dist def augment_grid(self, grid: Dict[Tuple, Tensor]) -> Dict[Tuple, Tensor]: - return {key: random_crop(item, size=self.size, dist=self.dist, exclude_last=True) + return {key: random_crop(item, size=self.size, dist=self.dist, grid_crop=True) for key, item in grid.items()} @@ -179,6 +179,13 @@ def augment_grid(self, grid: Dict[Tuple, Tensor]) -> Dict[Tuple, Tensor]: def radial_distortion_grid(grid: Tensor, scale: float) -> Tensor: + # spatial_shape = grid.shape[1:-1] + # new_grid = torch.stack([torch.meshgrid( + # *[torch.linspace(-1, 1, i) for i in spatial_shape])], dim=-1).to(grid) + # print(new_grid.shape) + # + # distortion = + dist = torch.norm(grid, 2, dim=-1, keepdim=True) dist = dist / dist.max() distortion = (scale[0] * dist.pow(3) + scale[1] * dist.pow(2) + scale[2] * dist) / 3 From 464e992defe19b1af908158b3c369bd3794f45a9 Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Sun, 23 Feb 2020 22:28:30 +0100 Subject: [PATCH 20/20] elastic deform prototyping --- notebooks/transformations.ipynb | 334 ++--------------------- rising/transforms/functional/__init__.py | 1 + rising/transforms/functional/kernel.py | 33 +++ rising/transforms/grid.py | 30 +- rising/transforms/kernel.py | 32 +-- 5 files changed, 83 insertions(+), 347 deletions(-) create mode 100644 rising/transforms/functional/kernel.py diff --git a/notebooks/transformations.ipynb b/notebooks/transformations.ipynb index 94407d8a..d5e313bd 100644 --- a/notebooks/transformations.ipynb +++ b/notebooks/transformations.ipynb @@ -261,7 +261,7 @@ "metadata": {}, "outputs": [], "source": [ - "trafo = CenterCropGrid(100)\n", + "trafo = CenterCropGrid(size=100)\n", "transformed = apply_transform(trafo, batch)\n", "view_batch(transformed)" ] @@ -272,337 +272,43 @@ "metadata": {}, "outputs": [], "source": [ - "trafo = RandomDistortion(noise_type=\"normal\", noise_kwargs={\"mean\": 0.1, \"std\": 0.005})\n", + "trafo = RandomCropGrid(size=100)\n", "transformed = apply_transform(trafo, batch)\n", "view_batch(transformed)" ] }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 35, "metadata": 
{}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "tensor(0.4000)\n", - "tensor(0.0002)\n", - "{(1, 1, 192, 192, 174): tensor([[[[[-0.5966, -0.5969, -0.5969],\n", - " [-0.5927, -0.5999, -0.5999],\n", - " [-0.5887, -0.6029, -0.6029],\n", - " ...,\n", - " [ 0.5887, -0.6029, -0.6029],\n", - " [ 0.5927, -0.5999, -0.5999],\n", - " [ 0.5966, -0.5969, -0.5969]],\n", - "\n", - " [[-0.5993, -0.5934, -0.5996],\n", - " [-0.5954, -0.5964, -0.6027],\n", - " [-0.5914, -0.5994, -0.6057],\n", - " ...,\n", - " [ 0.5914, -0.5994, -0.6057],\n", - " [ 0.5954, -0.5964, -0.6027],\n", - " [ 0.5993, -0.5934, -0.5996]],\n", - "\n", - " [[-0.6020, -0.5898, -0.6024],\n", - " [-0.5981, -0.5927, -0.6054],\n", - " [-0.5940, -0.5957, -0.6084],\n", - " ...,\n", - " [ 0.5940, -0.5957, -0.6084],\n", - " [ 0.5981, -0.5927, -0.6054],\n", - " [ 0.6020, -0.5898, -0.6024]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.6020, 0.5898, -0.6024],\n", - " [-0.5981, 0.5927, -0.6054],\n", - " [-0.5940, 0.5957, -0.6084],\n", - " ...,\n", - " [ 0.5940, 0.5957, -0.6084],\n", - " [ 0.5981, 0.5927, -0.6054],\n", - " [ 0.6020, 0.5898, -0.6024]],\n", - "\n", - " [[-0.5993, 0.5934, -0.5996],\n", - " [-0.5954, 0.5964, -0.6027],\n", - " [-0.5914, 0.5994, -0.6057],\n", - " ...,\n", - " [ 0.5914, 0.5994, -0.6057],\n", - " [ 0.5954, 0.5964, -0.6027],\n", - " [ 0.5993, 0.5934, -0.5996]],\n", - "\n", - " [[-0.5966, 0.5969, -0.5969],\n", - " [-0.5927, 0.5999, -0.5999],\n", - " [-0.5887, 0.6029, -0.6029],\n", - " ...,\n", - " [ 0.5887, 0.6029, -0.6029],\n", - " [ 0.5927, 0.5999, -0.5999],\n", - " [ 0.5966, 0.5969, -0.5969]]],\n", - "\n", - "\n", - " [[[-0.5993, -0.5996, -0.5934],\n", - " [-0.5954, -0.6027, -0.5964],\n", - " [-0.5914, -0.6057, -0.5994],\n", - " ...,\n", - " [ 0.5914, -0.6057, -0.5994],\n", - " [ 0.5954, -0.6027, -0.5964],\n", - " [ 0.5993, -0.5996, -0.5934]],\n", - "\n", - " [[-0.6021, -0.5961, -0.5961],\n", - " [-0.5981, -0.5991, -0.5991],\n", - " [-0.5941, -0.6021, -0.6021],\n", - " ...,\n", - " [ 0.5941, -0.6021, -0.6021],\n", - " [ 0.5981, -0.5991, -0.5991],\n", - " [ 0.6021, -0.5961, -0.5961]],\n", - "\n", - " [[-0.6048, -0.5925, -0.5988],\n", - " [-0.6008, -0.5954, -0.6018],\n", - " [-0.5967, -0.5984, -0.6048],\n", - " ...,\n", - " [ 0.5967, -0.5984, -0.6048],\n", - " [ 0.6008, -0.5954, -0.6018],\n", - " [ 0.6048, -0.5925, -0.5988]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.6048, 0.5925, -0.5988],\n", - " [-0.6008, 0.5954, -0.6018],\n", - " [-0.5967, 0.5984, -0.6048],\n", - " ...,\n", - " [ 0.5967, 0.5984, -0.6048],\n", - " [ 0.6008, 0.5954, -0.6018],\n", - " [ 0.6048, 0.5925, -0.5988]],\n", - "\n", - " [[-0.6021, 0.5961, -0.5961],\n", - " [-0.5981, 0.5991, -0.5991],\n", - " [-0.5941, 0.6021, -0.6021],\n", - " ...,\n", - " [ 0.5941, 0.6021, -0.6021],\n", - " [ 0.5981, 0.5991, -0.5991],\n", - " [ 0.6021, 0.5961, -0.5961]],\n", - "\n", - " [[-0.5993, 0.5996, -0.5934],\n", - " [-0.5954, 0.6027, -0.5964],\n", - " [-0.5914, 0.6057, -0.5994],\n", - " ...,\n", - " [ 0.5914, 0.6057, -0.5994],\n", - " [ 0.5954, 0.6027, -0.5964],\n", - " [ 0.5993, 0.5996, -0.5934]]],\n", - "\n", - "\n", - " [[[-0.6020, -0.6024, -0.5898],\n", - " [-0.5981, -0.6054, -0.5927],\n", - " [-0.5940, -0.6084, -0.5957],\n", - " ...,\n", - " [ 0.5940, -0.6084, -0.5957],\n", - " [ 0.5981, -0.6054, -0.5927],\n", - " [ 0.6020, -0.6024, -0.5898]],\n", - "\n", - " [[-0.6048, -0.5988, -0.5925],\n", - " [-0.6008, -0.6018, -0.5954],\n", - " [-0.5967, -0.6048, -0.5984],\n", - " ...,\n", - " [ 0.5967, -0.6048, -0.5984],\n", - " [ 0.6008, -0.6018, 
-0.5954],\n", - " [ 0.6048, -0.5988, -0.5925]],\n", - "\n", - " [[-0.6075, -0.5951, -0.5951],\n", - " [-0.6035, -0.5981, -0.5981],\n", - " [-0.5994, -0.6011, -0.6011],\n", - " ...,\n", - " [ 0.5994, -0.6011, -0.6011],\n", - " [ 0.6035, -0.5981, -0.5981],\n", - " [ 0.6075, -0.5951, -0.5951]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.6075, 0.5951, -0.5951],\n", - " [-0.6035, 0.5981, -0.5981],\n", - " [-0.5994, 0.6011, -0.6011],\n", - " ...,\n", - " [ 0.5994, 0.6011, -0.6011],\n", - " [ 0.6035, 0.5981, -0.5981],\n", - " [ 0.6075, 0.5951, -0.5951]],\n", - "\n", - " [[-0.6048, 0.5988, -0.5925],\n", - " [-0.6008, 0.6018, -0.5954],\n", - " [-0.5967, 0.6048, -0.5984],\n", - " ...,\n", - " [ 0.5967, 0.6048, -0.5984],\n", - " [ 0.6008, 0.6018, -0.5954],\n", - " [ 0.6048, 0.5988, -0.5925]],\n", - "\n", - " [[-0.6020, 0.6024, -0.5898],\n", - " [-0.5981, 0.6054, -0.5927],\n", - " [-0.5940, 0.6084, -0.5957],\n", - " ...,\n", - " [ 0.5940, 0.6084, -0.5957],\n", - " [ 0.5981, 0.6054, -0.5927],\n", - " [ 0.6020, 0.6024, -0.5898]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[-0.6020, -0.6024, 0.5898],\n", - " [-0.5981, -0.6054, 0.5927],\n", - " [-0.5940, -0.6084, 0.5957],\n", - " ...,\n", - " [ 0.5940, -0.6084, 0.5957],\n", - " [ 0.5981, -0.6054, 0.5927],\n", - " [ 0.6020, -0.6024, 0.5898]],\n", - "\n", - " [[-0.6048, -0.5988, 0.5925],\n", - " [-0.6008, -0.6018, 0.5954],\n", - " [-0.5967, -0.6048, 0.5984],\n", - " ...,\n", - " [ 0.5967, -0.6048, 0.5984],\n", - " [ 0.6008, -0.6018, 0.5954],\n", - " [ 0.6048, -0.5988, 0.5925]],\n", - "\n", - " [[-0.6075, -0.5951, 0.5951],\n", - " [-0.6035, -0.5981, 0.5981],\n", - " [-0.5994, -0.6011, 0.6011],\n", - " ...,\n", - " [ 0.5994, -0.6011, 0.6011],\n", - " [ 0.6035, -0.5981, 0.5981],\n", - " [ 0.6075, -0.5951, 0.5951]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.6075, 0.5951, 0.5951],\n", - " [-0.6035, 0.5981, 0.5981],\n", - " [-0.5994, 0.6011, 0.6011],\n", - " ...,\n", - " [ 0.5994, 0.6011, 0.6011],\n", - " [ 0.6035, 0.5981, 0.5981],\n", - " [ 0.6075, 0.5951, 0.5951]],\n", - "\n", - " [[-0.6048, 0.5988, 0.5925],\n", - " [-0.6008, 0.6018, 0.5954],\n", - " [-0.5967, 0.6048, 0.5984],\n", - " ...,\n", - " [ 0.5967, 0.6048, 0.5984],\n", - " [ 0.6008, 0.6018, 0.5954],\n", - " [ 0.6048, 0.5988, 0.5925]],\n", - "\n", - " [[-0.6020, 0.6024, 0.5898],\n", - " [-0.5981, 0.6054, 0.5927],\n", - " [-0.5940, 0.6084, 0.5957],\n", - " ...,\n", - " [ 0.5940, 0.6084, 0.5957],\n", - " [ 0.5981, 0.6054, 0.5927],\n", - " [ 0.6020, 0.6024, 0.5898]]],\n", - "\n", - "\n", - " [[[-0.5993, -0.5996, 0.5934],\n", - " [-0.5954, -0.6027, 0.5964],\n", - " [-0.5914, -0.6057, 0.5994],\n", - " ...,\n", - " [ 0.5914, -0.6057, 0.5994],\n", - " [ 0.5954, -0.6027, 0.5964],\n", - " [ 0.5993, -0.5996, 0.5934]],\n", - "\n", - " [[-0.6021, -0.5961, 0.5961],\n", - " [-0.5981, -0.5991, 0.5991],\n", - " [-0.5941, -0.6021, 0.6021],\n", - " ...,\n", - " [ 0.5941, -0.6021, 0.6021],\n", - " [ 0.5981, -0.5991, 0.5991],\n", - " [ 0.6021, -0.5961, 0.5961]],\n", - "\n", - " [[-0.6048, -0.5925, 0.5988],\n", - " [-0.6008, -0.5954, 0.6018],\n", - " [-0.5967, -0.5984, 0.6048],\n", - " ...,\n", - " [ 0.5967, -0.5984, 0.6048],\n", - " [ 0.6008, -0.5954, 0.6018],\n", - " [ 0.6048, -0.5925, 0.5988]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.6048, 0.5925, 0.5988],\n", - " [-0.6008, 0.5954, 0.6018],\n", - " [-0.5967, 0.5984, 0.6048],\n", - " ...,\n", - " [ 0.5967, 0.5984, 0.6048],\n", - " [ 0.6008, 0.5954, 0.6018],\n", - " [ 0.6048, 0.5925, 0.5988]],\n", - "\n", - " [[-0.6021, 0.5961, 0.5961],\n", - " [-0.5981, 
0.5991, 0.5991],\n", - " [-0.5941, 0.6021, 0.6021],\n", - " ...,\n", - " [ 0.5941, 0.6021, 0.6021],\n", - " [ 0.5981, 0.5991, 0.5991],\n", - " [ 0.6021, 0.5961, 0.5961]],\n", - "\n", - " [[-0.5993, 0.5996, 0.5934],\n", - " [-0.5954, 0.6027, 0.5964],\n", - " [-0.5914, 0.6057, 0.5994],\n", - " ...,\n", - " [ 0.5914, 0.6057, 0.5994],\n", - " [ 0.5954, 0.6027, 0.5964],\n", - " [ 0.5993, 0.5996, 0.5934]]],\n", - "\n", - "\n", - " [[[-0.5966, -0.5969, 0.5969],\n", - " [-0.5927, -0.5999, 0.5999],\n", - " [-0.5887, -0.6029, 0.6029],\n", - " ...,\n", - " [ 0.5887, -0.6029, 0.6029],\n", - " [ 0.5927, -0.5999, 0.5999],\n", - " [ 0.5966, -0.5969, 0.5969]],\n", - "\n", - " [[-0.5993, -0.5934, 0.5996],\n", - " [-0.5954, -0.5964, 0.6027],\n", - " [-0.5914, -0.5994, 0.6057],\n", - " ...,\n", - " [ 0.5914, -0.5994, 0.6057],\n", - " [ 0.5954, -0.5964, 0.6027],\n", - " [ 0.5993, -0.5934, 0.5996]],\n", - "\n", - " [[-0.6020, -0.5898, 0.6024],\n", - " [-0.5981, -0.5927, 0.6054],\n", - " [-0.5940, -0.5957, 0.6084],\n", - " ...,\n", - " [ 0.5940, -0.5957, 0.6084],\n", - " [ 0.5981, -0.5927, 0.6054],\n", - " [ 0.6020, -0.5898, 0.6024]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.6020, 0.5898, 0.6024],\n", - " [-0.5981, 0.5927, 0.6054],\n", - " [-0.5940, 0.5957, 0.6084],\n", - " ...,\n", - " [ 0.5940, 0.5957, 0.6084],\n", - " [ 0.5981, 0.5927, 0.6054],\n", - " [ 0.6020, 0.5898, 0.6024]],\n", - "\n", - " [[-0.5993, 0.5934, 0.5996],\n", - " [-0.5954, 0.5964, 0.6027],\n", - " [-0.5914, 0.5994, 0.6057],\n", - " ...,\n", - " [ 0.5914, 0.5994, 0.6057],\n", - " [ 0.5954, 0.5964, 0.6027],\n", - " [ 0.5993, 0.5934, 0.5996]],\n", - "\n", - " [[-0.5966, 0.5969, 0.5969],\n", - " [-0.5927, 0.5999, 0.5999],\n", - " [-0.5887, 0.6029, 0.6029],\n", - " ...,\n", - " [ 0.5887, 0.6029, 0.6029],\n", - " [ 0.5927, 0.5999, 0.5999],\n", - " [ 0.5966, 0.5969, 0.5969]]]]])}\n", + "torch.Size([1, 1, 192, 192, 174])\n", + "torch.Size([1, 192, 192, 174, 3])\n", + "tensor(0.0100)\n", + "tensor(-0.0100)\n", "Transformed data shape: torch.Size([1, 1, 192, 192, 174])\n", "Transformed mask shape: torch.Size([1, 1, 192, 192, 174])\n", "Transformed data min: 0.0\n", - "Transformed data max: 459.3555603027344\n", - "Transformed data mean: 54.32514190673828\n" + "Transformed data max: 474.02880859375\n", + "Transformed data mean: 37.61506652832031\n" ] } ], + "source": [ + "trafo = ElasticDistortion(alpha=0.01, std=[0.1, 0.1, 0.000001], dim=3)\n", + "transformed = apply_transform(trafo, batch)\n", + "view_batch(transformed)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "trafo = RadialDistortion(scale=[0.1, 1., 0.1])\n", "transformed = apply_transform(trafo, batch)\n", diff --git a/rising/transforms/functional/__init__.py b/rising/transforms/functional/__init__.py index b1a5aabe..1117a081 100644 --- a/rising/transforms/functional/__init__.py +++ b/rising/transforms/functional/__init__.py @@ -5,3 +5,4 @@ from rising.transforms.functional.tensor import * from rising.transforms.functional.utility import * from rising.transforms.functional.channel import * +from rising.transforms.functional.kernel import * diff --git a/rising/transforms/functional/kernel.py b/rising/transforms/functional/kernel.py new file mode 100644 index 00000000..e287bd41 --- /dev/null +++ b/rising/transforms/functional/kernel.py @@ -0,0 +1,33 @@ +import math +import torch + +from typing import Sequence, Union + +from rising.utils import check_scalar + +__all__ = ["gaussian_kernel"] + + +def gaussian_kernel(dim: int, 
kernel_size: Union[int, Sequence[int]], + std: Union[float, Sequence[float]], in_channels: int = 1) -> torch.Tensor: + if check_scalar(kernel_size): + kernel_size = [kernel_size] * dim + if check_scalar(std): + std = [std] * dim + # The gaussian kernel is the product of the gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) + for size in kernel_size]) + + for size, std, mgrid in zip(kernel_size, std, meshgrids): + mean = (size - 1) / 2 + kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-((mgrid - mean) / std) ** 2 / 2) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / kernel.sum() + + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(in_channels, *[1] * (kernel.dim() - 1)) + kernel.requires_grad = False + return kernel.contiguous() diff --git a/rising/transforms/grid.py b/rising/transforms/grid.py index 2bc58db3..961fd4d6 100644 --- a/rising/transforms/grid.py +++ b/rising/transforms/grid.py @@ -5,13 +5,13 @@ from abc import abstractmethod from torch import Tensor -from rising.transforms import AbstractTransform +from rising.transforms import AbstractTransform, GaussianSmoothing from rising.utils.affine import get_batched_eye, matrix_to_homogeneous -from rising.transforms.functional import center_crop, random_crop, add_noise +from rising.transforms.functional import center_crop, random_crop __all__ = ["GridTransform", "StackedGridTransform", - "CenterCropGrid", "RandomCropGrid", "RandomDistortion", "RadialDistortion"] + "CenterCropGrid", "RandomCropGrid", "ElasticDistortion", "RadialDistortion"] class GridTransform(AbstractTransform): @@ -135,10 +135,11 @@ def augment_grid(self, grid: Dict[Tuple, Tensor]) -> Dict[Tuple, Tensor]: for key, item in grid.items()} -class RandomDistortion(GridTransform): +class ElasticDistortion(GridTransform): def __init__(self, - noise_type: str, - noise_kwargs: dict = None, + std: Union[float, Sequence[float]], + alpha: float, + dim: int = 2, keys: Sequence[str] = ('data',), interpolation_mode: str = 'bilinear', padding_mode: str = 'zeros', @@ -148,12 +149,21 @@ def __init__(self, super().__init__(keys=keys, interpolation_mode=interpolation_mode, padding_mode=padding_mode, align_corners=align_corners, grad=grad, **kwargs) - self.noise_type = noise_type - self.noise_kwargs = noise_kwargs if noise_kwargs is not None else {} + self.std = std + self.alpha = alpha + self.gaussian = GaussianSmoothing(in_channels=1, kernel_size=7, std=self.std, + dim=dim, stride=1, padding=3) def augment_grid(self, grid: Dict[Tuple, Tensor]) -> Dict[Tuple, Tensor]: - return {key: add_noise(item, noise_type=self.noise_type, **self.noise_kwargs) - for key, item in grid.items()} + for key in grid.keys(): + random_offsets = torch.rand(1, 1, *grid[key].shape[1:-1]) * 2 - 1 + random_offsets = self.gaussian(**{"data": random_offsets})["data"] * self.alpha + print(random_offsets.shape) + print(grid[key].shape) + print(random_offsets.max()) + print(random_offsets.min()) + grid[key] += random_offsets[:, 0, ..., None] + return grid class RadialDistortion(GridTransform): diff --git a/rising/transforms/kernel.py b/rising/transforms/kernel.py index 5937e5d5..f6730aa1 100644 --- a/rising/transforms/kernel.py +++ b/rising/transforms/kernel.py @@ -4,6 +4,7 @@ from .abstract import AbstractTransform from rising.utils import check_scalar +from rising.transforms.functional import gaussian_kernel __all__ = ["KernelTransform", 
"GaussianSmoothing"] @@ -111,10 +112,12 @@ def forward(self, **data) -> dict: class GaussianSmoothing(KernelTransform): - def __init__(self, in_channels: int, kernel_size: Union[int, Sequence], + def __init__(self, + in_channels: int, + kernel_size: Union[int, Sequence], std: Union[int, Sequence], dim: int = 2, stride: Union[int, Sequence] = 1, padding: Union[int, Sequence] = 0, - padding_mode: str = 'reflect', keys: Sequence = ('data',), grad: bool = False, + padding_mode: str = 'constant', keys: Sequence = ('data',), grad: bool = False, **kwargs): """ Perform Gaussian Smoothing. @@ -150,9 +153,8 @@ def __init__(self, in_channels: int, kernel_size: Union[int, Sequence], -------- :func:`torch.functional.pad` """ - if check_scalar(std): - std = [std] * dim self.std = std + self.spatial_dim = dim super().__init__(in_channels=in_channels, kernel_size=kernel_size, dim=dim, stride=stride, padding=padding, padding_mode=padding_mode, keys=keys, grad=grad, **kwargs) @@ -160,22 +162,6 @@ def create_kernel(self) -> torch.Tensor: """ Create gaussian blur kernel """ - # The gaussian kernel is the product of the gaussian function of each dimension. - kernel = 1 - meshgrids = torch.meshgrid([ - torch.arange(size, dtype=torch.float32) - for size in self.kernel_size - ]) - - for size, std, mgrid in zip(self.kernel_size, self.std, meshgrids): - mean = (size - 1) / 2 - kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-((mgrid - mean) / std) ** 2 / 2) - - # Make sure sum of values in gaussian kernel equals 1. - kernel = kernel / kernel.sum() - - # Reshape to depthwise convolutional weight - kernel = kernel.view(1, 1, *kernel.size()) - kernel = kernel.repeat(self.in_channels, *[1] * (kernel.dim() - 1)) - kernel.requires_grad = False - return kernel.contiguous() + return gaussian_kernel(kernel_size=self.kernel_size, + std=self.std, in_channels=self.in_channels, + dim=self.spatial_dim)